Skip to content
Snippets Groups Projects
Commit 7739bcdd authored by akwizgran's avatar akwizgran
Browse files

Second part of key rotation implementation. Work in progress.

parent 021b3c5a
No related branches found
No related tags found
No related merge requests found
package net.sf.briar.api.transport;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId;
......@@ -14,5 +15,12 @@ public interface ConnectionRecogniser {
* expected, or null if the connection was not expected.
*/
ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException;
throws DbException;
void addWindow(ContactId c, TransportId t, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException;
void removeWindow(ContactId c, TransportId t, long period);
void removeWindows(ContactId c);
}
package net.sf.briar.api.transport;
import java.util.Set;
public interface ConnectionWindow {
boolean isSeen(long connection);
void setSeen(long connection);
Set<Long> getUnseen();
}
package net.sf.briar.transport;
import java.util.HashMap;
import java.util.Map;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import com.google.inject.Inject;
class ConnectionRecogniserImpl implements ConnectionRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
// Locking: this
private final Map<TransportId, TransportConnectionRecogniser> recognisers;
@Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db) {
this.crypto = crypto;
this.db = db;
recognisers = new HashMap<TransportId, TransportConnectionRecogniser>();
}
public ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r == null) return null;
return r.acceptConnection(tag);
}
public void addWindow(ContactId c, TransportId t, long period,
boolean alice, byte[] secret, long centre, byte[] bitmap)
throws DbException {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
if(r == null) {
r = new TransportConnectionRecogniser(crypto, db, t);
recognisers.put(t, r);
}
}
r.addWindow(c, period, alice, secret, centre, bitmap);
}
public void removeWindow(ContactId c, TransportId t, long period) {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r != null) r.removeWindow(c, period);
}
public synchronized void removeWindows(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values()) {
r.removeWindows(c);
}
}
}
......@@ -3,59 +3,75 @@ package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import net.sf.briar.api.transport.ConnectionWindow;
// This class is not thread-safe
class ConnectionWindowImpl implements ConnectionWindow {
class ConnectionWindow {
private final Set<Long> unseen;
private long centre;
ConnectionWindowImpl() {
ConnectionWindow() {
unseen = new HashSet<Long>();
for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l);
centre = 0;
}
ConnectionWindowImpl(Set<Long> unseen) {
long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
for(long l : unseen) {
if(l < 0 || l > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(l < min) min = l;
if(l > max) max = l;
}
if(max - min > CONNECTION_WINDOW_SIZE)
ConnectionWindow(long centre, byte[] bitmap) {
if(centre < 0 || centre > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(bitmap.length != CONNECTION_WINDOW_SIZE / 8)
throw new IllegalArgumentException();
centre = max - CONNECTION_WINDOW_SIZE / 2 + 1;
for(long l = centre; l <= max; l++) {
if(!unseen.contains(l)) throw new IllegalArgumentException();
unseen = new HashSet<Long>();
long bottom = getBottom(centre);
long top = getTop(centre);
for(long l = bottom; l < top; l++) {
int offset = (int) (l - bottom);
int bytes = offset / 8;
int bits = offset % 8;
if((bitmap[bytes] & (128 >> bits)) == 0) unseen.add(l);
}
this.unseen = unseen;
this.centre = centre;
}
public boolean isSeen(long connection) {
boolean isSeen(long connection) {
return !unseen.contains(connection);
}
public void setSeen(long connection) {
Collection<Long> setSeen(long connection) {
long bottom = getBottom(centre);
long top = getTop(centre);
if(connection < bottom || connection > top)
throw new IllegalArgumentException();
if(!unseen.remove(connection))
throw new IllegalArgumentException();
Collection<Long> changed = new ArrayList<Long>();
if(connection >= centre) {
centre = connection + 1;
long newBottom = getBottom(centre);
long newTop = getTop(centre);
for(long l = bottom; l < newBottom; l++) unseen.remove(l);
for(long l = top + 1; l <= newTop; l++) unseen.add(l);
for(long l = bottom; l < newBottom; l++) {
if(unseen.remove(l)) changed.add(l);
}
for(long l = top + 1; l <= newTop; l++) {
if(unseen.add(l)) changed.add(l);
}
}
return changed;
}
long getCentre() {
return centre;
}
byte[] getBitmap() {
byte[] bitmap = new byte[CONNECTION_WINDOW_SIZE / 8];
// FIXME
return bitmap;
}
// Returns the lowest value contained in a window with the given centre
......@@ -69,7 +85,7 @@ class ConnectionWindowImpl implements ConnectionWindow {
centre + CONNECTION_WINDOW_SIZE / 2 - 1);
}
public Set<Long> getUnseen() {
public Collection<Long> getUnseen() {
return unseen;
}
}
package net.sf.briar.transport;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.util.ByteUtils;
class TagEncoder {
static void encodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
static void encodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey,
long connection) {
if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
// Blank plaintext
if(connection < 0 || connection > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
for(int i = 0; i < TAG_LENGTH; i++) tag[i] = 0;
ByteUtils.writeUint32(connection, tag, 0);
try {
tagCipher.init(ENCRYPT_MODE, tagKey);
int encrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
......@@ -25,19 +29,4 @@ class TagEncoder {
throw new IllegalArgumentException(e);
}
}
static boolean decodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
try {
tagCipher.init(DECRYPT_MODE, tagKey);
int decrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
if(decrypted != TAG_LENGTH) throw new IllegalArgumentException();
//The plaintext should be blank
for(int i = 0; i < TAG_LENGTH; i++) if(tag[i] != 0) return false;
return true;
} catch(GeneralSecurityException e) {
// Unsuitable cipher or key
throw new IllegalArgumentException(e);
}
}
}
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import javax.crypto.Cipher;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.util.ByteUtils;
/** A connection recogniser for a specific transport. */
class TransportConnectionRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final TransportId transportId;
private final Map<Bytes, WindowContext> tagMap; // Locking: this
private final Map<RemovalKey, RemovalContext> windowMap; // Locking: this
TransportConnectionRecogniser(CryptoComponent crypto, DatabaseComponent db,
TransportId transportId) {
this.crypto = crypto;
this.db = db;
this.transportId = transportId;
tagMap = new HashMap<Bytes, WindowContext>();
windowMap = new HashMap<RemovalKey, RemovalContext>();
}
synchronized ConnectionContext acceptConnection(byte[] tag)
throws DbException {
WindowContext wctx = tagMap.remove(new Bytes(tag));
if(wctx == null) return null;
ConnectionWindow w = wctx.window;
ConnectionContext ctx = wctx.context;
long connection = ctx.getConnectionNumber();
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(ctx.getSecret(), ctx.getAlice());
byte[] changedTag = new byte[TAG_LENGTH];
Bytes changedTagWrapper = new Bytes(changedTag);
for(long conn : w.setSeen(connection)) {
TagEncoder.encodeTag(changedTag, cipher, key, conn);
WindowContext old;
if(conn <= connection) old = tagMap.remove(changedTagWrapper);
else old = tagMap.put(changedTagWrapper, wctx);
assert old == null;
}
key.erase();
ContactId c = ctx.getContactId();
long centre = w.getCentre();
byte[] bitmap = w.getBitmap();
db.setConnectionWindow(c, transportId, wctx.period, centre, bitmap);
return wctx.context;
}
synchronized void addWindow(ContactId c, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException {
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(secret, alice);
ConnectionWindow w = new ConnectionWindow(centre, bitmap);
for(long conn : w.getUnseen()) {
byte[] tag = new byte[TAG_LENGTH];
TagEncoder.encodeTag(tag, cipher, key, conn);
ConnectionContext ctx = new ConnectionContext(c, transportId, tag,
secret, conn, alice);
WindowContext wctx = new WindowContext(w, ctx, period);
tagMap.put(new Bytes(tag), wctx);
}
db.setConnectionWindow(c, transportId, period, centre, bitmap);
RemovalContext ctx = new RemovalContext(w, secret, alice);
windowMap.put(new RemovalKey(c, period), ctx);
}
synchronized void removeWindow(ContactId c, long period) {
RemovalContext ctx = windowMap.remove(new RemovalKey(c, period));
if(ctx == null) throw new IllegalArgumentException();
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(ctx.secret, ctx.alice);
byte[] removedTag = new byte[TAG_LENGTH];
Bytes removedTagWrapper = new Bytes(removedTag);
for(long conn : ctx.window.getUnseen()) {
TagEncoder.encodeTag(removedTag, cipher, key, conn);
WindowContext old = tagMap.remove(removedTagWrapper);
assert old != null;
}
key.erase();
ByteUtils.erase(ctx.secret);
}
synchronized void removeWindows(ContactId c) {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for(RemovalKey k : windowMap.keySet()) {
if(k.contactId.equals(c)) keysToRemove.add(k);
}
for(RemovalKey k : keysToRemove) removeWindow(k.contactId, k.period);
}
private static class WindowContext {
private final ConnectionWindow window;
private final ConnectionContext context;
private final long period;
private WindowContext(ConnectionWindow window,
ConnectionContext context, long period) {
this.window = window;
this.context = context;
this.period = period;
}
}
private static class RemovalKey {
private final ContactId contactId;
private final long period;
private RemovalKey(ContactId contactId, long period) {
this.contactId = contactId;
this.period = period;
}
@Override
public int hashCode() {
return contactId.hashCode()+ (int) period;
}
@Override
public boolean equals(Object o) {
if(o instanceof RemovalKey) {
RemovalKey w = (RemovalKey) o;
return contactId.equals(w.contactId) && period == w.period;
}
return false;
}
}
private static class RemovalContext {
private final ConnectionWindow window;
private final byte[] secret;
private final boolean alice;
private RemovalContext(ConnectionWindow window, byte[] secret,
boolean alice) {
this.window = window;
this.secret = secret;
this.alice = alice;
}
}
}
......@@ -49,7 +49,7 @@
<test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.transport.ConnectionReaderImplTest'/>
<test name='net.sf.briar.transport.ConnectionRegistryImplTest'/>
<test name='net.sf.briar.transport.ConnectionWindowImplTest'/>
<test name='net.sf.briar.transport.ConnectionWindowTest'/>
<test name='net.sf.briar.transport.ConnectionWriterImplTest'/>
<test name='net.sf.briar.transport.IncomingEncryptionLayerTest'/>
<test name='net.sf.briar.transport.OutgoingEncryptionLayerTest'/>
......
package net.sf.briar.transport;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.HashSet;
import java.util.Set;
import java.util.Collection;
import net.sf.briar.BriarTestCase;
import net.sf.briar.api.transport.ConnectionWindow;
import org.junit.Test;
public class ConnectionWindowImplTest extends BriarTestCase {
public class ConnectionWindowTest extends BriarTestCase {
@Test
public void testWindowSliding() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
for(int i = 0; i < 100; i++) {
assertFalse(w.isSeen(i));
w.setSeen(i);
......@@ -24,7 +20,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testWindowJumping() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
for(int i = 0; i < 100; i += 13) {
assertFalse(w.isSeen(i));
w.setSeen(i);
......@@ -34,7 +30,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testWindowUpperLimit() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
// Centre is 0, highest value in window is 15
w.setSeen(15);
// Centre is 16, highest value in window is 31
......@@ -44,20 +40,11 @@ public class ConnectionWindowImplTest extends BriarTestCase {
w.setSeen(48);
fail();
} catch(IllegalArgumentException expected) {}
// Values greater than 2^32 - 1 should never be allowed
Set<Long> unseen = new HashSet<Long>();
for(int i = 0; i < 32; i++) unseen.add(MAX_32_BIT_UNSIGNED - i);
w = new ConnectionWindowImpl(unseen);
w.setSeen(MAX_32_BIT_UNSIGNED);
try {
w.setSeen(MAX_32_BIT_UNSIGNED + 1);
fail();
} catch(IllegalArgumentException expected) {}
}
@Test
public void testWindowLowerLimit() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
// Centre is 0, negative values should never be allowed
try {
w.setSeen(-1);
......@@ -87,7 +74,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testCannotSetSeenTwice() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
w.setSeen(15);
try {
w.setSeen(15);
......@@ -97,9 +84,9 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testGetUnseenConnectionNumbers() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
// Centre is 0; window should cover 0 to 15, inclusive, with none seen
Set<Long> unseen = w.getUnseen();
Collection<Long> unseen = w.getUnseen();
assertEquals(16, unseen.size());
for(int i = 0; i < 16; i++) {
assertTrue(unseen.contains(Long.valueOf(i)));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment