diff --git a/api/net/sf/briar/api/protocol/ProtocolReader.java b/api/net/sf/briar/api/protocol/ProtocolReader.java new file mode 100644 index 0000000000000000000000000000000000000000..1ed7866cd5b987edc2ac3dffc965ab78d4b756b1 --- /dev/null +++ b/api/net/sf/briar/api/protocol/ProtocolReader.java @@ -0,0 +1,24 @@ +package net.sf.briar.api.protocol; + +import java.io.IOException; + +public interface ProtocolReader { + + boolean hasAck() throws IOException; + Ack readAck() throws IOException; + + boolean hasBatch() throws IOException; + Batch readBatch() throws IOException; + + boolean hasOffer() throws IOException; + Offer readOffer() throws IOException; + + boolean hasRequest() throws IOException; + Request readRequest() throws IOException; + + boolean hasSubscriptionUpdate() throws IOException; + SubscriptionUpdate readSubscriptionUpdate() throws IOException; + + boolean hasTransportUpdate() throws IOException; + TransportUpdate readTransportUpdate() throws IOException; +} diff --git a/api/net/sf/briar/api/protocol/ProtocolReaderFactory.java b/api/net/sf/briar/api/protocol/ProtocolReaderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..2eb1a75187894ea9dffd3bb0e96c8221db9d5656 --- /dev/null +++ b/api/net/sf/briar/api/protocol/ProtocolReaderFactory.java @@ -0,0 +1,8 @@ +package net.sf.briar.api.protocol; + +import java.io.InputStream; + +public interface ProtocolReaderFactory { + + ProtocolReader createProtocolReader(InputStream in); +} diff --git a/api/net/sf/briar/api/transport/PacketReader.java b/api/net/sf/briar/api/transport/PacketReader.java index 95001e2116cccae6f8c36a34afa6dbab5fc1f91e..683f3cea087276516954fb4782ee2b555c056119 100644 --- a/api/net/sf/briar/api/transport/PacketReader.java +++ b/api/net/sf/briar/api/transport/PacketReader.java @@ -1,13 +1,8 @@ package net.sf.briar.api.transport; import java.io.IOException; - -import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; -import net.sf.briar.api.protocol.Offer; -import net.sf.briar.api.protocol.Request; -import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.TransportUpdate; +import java.io.InputStream; +import java.security.GeneralSecurityException; /** * Reads encrypted packets from an underlying input stream, decrypts and @@ -15,21 +10,17 @@ import net.sf.briar.api.protocol.TransportUpdate; */ public interface PacketReader { - boolean hasAck() throws IOException; - Ack readAck() throws IOException; - - boolean hasBatch() throws IOException; - Batch readBatch() throws IOException; - - boolean hasOffer() throws IOException; - Offer readOffer() throws IOException; - - boolean hasRequest() throws IOException; - Request readRequest() throws IOException; - - boolean hasSubscriptionUpdate() throws IOException; - SubscriptionUpdate readSubscriptionUpdate() throws IOException; - - boolean hasTransportUpdate() throws IOException; - TransportUpdate readTransportUpdate() throws IOException; + /** + * Returns the input stream from which packets should be read. (Note that + * this is not the underlying input stream.) + */ + InputStream getInputStream(); + + /** + * Finishes reading the current packet (if any), authenticates the packet + * and prepares to read the next packet. If this method is called twice in + * succession without any intervening reads, the underlying input stream + * will be unaffected. + */ + void finishPacket() throws IOException, GeneralSecurityException; } diff --git a/api/net/sf/briar/api/transport/PacketWriter.java b/api/net/sf/briar/api/transport/PacketWriter.java index f88bd02eb4ff5bf9d2bba40aaf848ea3423283e6..bc4d92a0a682bdb1a0e1677fd5c9de9916ace839 100644 --- a/api/net/sf/briar/api/transport/PacketWriter.java +++ b/api/net/sf/briar/api/transport/PacketWriter.java @@ -20,5 +20,5 @@ public interface PacketWriter { * next packet. If this method is called twice in succession without any * intervening writes, the underlying output stream will be unaffected. */ - void nextPacket() throws IOException; + void finishPacket() throws IOException; } diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index dba8dbd0c351038eb270f1d288fd0bc136b764a8..6d36d8efb38cc18522aa7a1f9d5a7f83dcb95c53 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -12,6 +12,7 @@ import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; @@ -28,11 +29,12 @@ public class ProtocolModule extends AbstractModule { bind(AuthorFactory.class).to(AuthorFactoryImpl.class); bind(BatchFactory.class).to(BatchFactoryImpl.class); bind(GroupFactory.class).to(GroupFactoryImpl.class); + bind(MessageEncoder.class).to(MessageEncoderImpl.class); bind(OfferFactory.class).to(OfferFactoryImpl.class); + bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class); bind(RequestFactory.class).to(RequestFactoryImpl.class); bind(SubscriptionFactory.class).to(SubscriptionFactoryImpl.class); bind(TransportFactory.class).to(TransportFactoryImpl.class); - bind(MessageEncoder.class).to(MessageEncoderImpl.class); } @Provides diff --git a/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java b/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..d1188805018375f4eb07d253f02d3a32485ec20a --- /dev/null +++ b/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java @@ -0,0 +1,51 @@ +package net.sf.briar.protocol; + +import java.io.InputStream; + +import net.sf.briar.api.protocol.Ack; +import net.sf.briar.api.protocol.Batch; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.ProtocolReader; +import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.SubscriptionUpdate; +import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.serial.ObjectReader; +import net.sf.briar.api.serial.ReaderFactory; + +import com.google.inject.Inject; +import com.google.inject.Provider; + +class ProtocolReaderFactoryImpl implements ProtocolReaderFactory { + + private final ReaderFactory readerFactory; + private final Provider<ObjectReader<Ack>> ackProvider; + private final Provider<ObjectReader<Batch>> batchProvider; + private final Provider<ObjectReader<Offer>> offerProvider; + private final Provider<ObjectReader<Request>> requestProvider; + private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider; + private final Provider<ObjectReader<TransportUpdate>> transportProvider; + + @Inject + ProtocolReaderFactoryImpl(ReaderFactory readerFactory, + Provider<ObjectReader<Ack>> ackProvider, + Provider<ObjectReader<Batch>> batchProvider, + Provider<ObjectReader<Offer>> offerProvider, + Provider<ObjectReader<Request>> requestProvider, + Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider, + Provider<ObjectReader<TransportUpdate>> transportProvider) { + this.readerFactory = readerFactory; + this.ackProvider = ackProvider; + this.batchProvider = batchProvider; + this.offerProvider = offerProvider; + this.requestProvider = requestProvider; + this.subscriptionProvider = subscriptionProvider; + this.transportProvider = transportProvider; + } + + public ProtocolReader createProtocolReader(InputStream in) { + return new ProtocolReaderImpl(in, readerFactory, ackProvider.get(), + batchProvider.get(), offerProvider.get(), requestProvider.get(), + subscriptionProvider.get(), transportProvider.get()); + } +} diff --git a/components/net/sf/briar/protocol/ProtocolReaderImpl.java b/components/net/sf/briar/protocol/ProtocolReaderImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..5797d89b3d08b287340120fb6cb261fd0c98af05 --- /dev/null +++ b/components/net/sf/briar/protocol/ProtocolReaderImpl.java @@ -0,0 +1,85 @@ +package net.sf.briar.protocol; + +import java.io.IOException; +import java.io.InputStream; + +import net.sf.briar.api.protocol.Ack; +import net.sf.briar.api.protocol.Batch; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.ProtocolReader; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.SubscriptionUpdate; +import net.sf.briar.api.protocol.Tags; +import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.serial.ObjectReader; +import net.sf.briar.api.serial.Reader; +import net.sf.briar.api.serial.ReaderFactory; + +class ProtocolReaderImpl implements ProtocolReader { + + private final Reader reader; + + ProtocolReaderImpl(InputStream in, ReaderFactory readerFactory, + ObjectReader<Ack> ackReader, ObjectReader<Batch> batchReader, + ObjectReader<Offer> offerReader, + ObjectReader<Request> requestReader, + ObjectReader<SubscriptionUpdate> subscriptionReader, + ObjectReader<TransportUpdate> transportReader) { + reader = readerFactory.createReader(in); + reader.addObjectReader(Tags.ACK, ackReader); + reader.addObjectReader(Tags.BATCH, batchReader); + reader.addObjectReader(Tags.OFFER, offerReader); + reader.addObjectReader(Tags.REQUEST, requestReader); + reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader); + reader.addObjectReader(Tags.TRANSPORTS, transportReader); + } + + public boolean hasAck() throws IOException { + return reader.hasUserDefined(Tags.ACK); + } + + public Ack readAck() throws IOException { + return reader.readUserDefined(Tags.ACK, Ack.class); + } + + public boolean hasBatch() throws IOException { + return reader.hasUserDefined(Tags.BATCH); + } + + public Batch readBatch() throws IOException { + return reader.readUserDefined(Tags.BATCH, Batch.class); + } + + public boolean hasOffer() throws IOException { + return reader.hasUserDefined(Tags.OFFER); + } + + public Offer readOffer() throws IOException { + return reader.readUserDefined(Tags.OFFER, Offer.class); + } + + public boolean hasRequest() throws IOException { + return reader.hasUserDefined(Tags.REQUEST); + } + + public Request readRequest() throws IOException { + return reader.readUserDefined(Tags.REQUEST, Request.class); + } + + public boolean hasSubscriptionUpdate() throws IOException { + return reader.hasUserDefined(Tags.SUBSCRIPTIONS); + } + + public SubscriptionUpdate readSubscriptionUpdate() throws IOException { + return reader.readUserDefined(Tags.SUBSCRIPTIONS, + SubscriptionUpdate.class); + } + + public boolean hasTransportUpdate() throws IOException { + return reader.hasUserDefined(Tags.TRANSPORTS); + } + + public TransportUpdate readTransportUpdate() throws IOException { + return reader.readUserDefined(Tags.TRANSPORTS, TransportUpdate.class); + } +} diff --git a/components/net/sf/briar/transport/PacketDecrypter.java b/components/net/sf/briar/transport/PacketDecrypter.java index c8587d3ae02e0115895cde3b20a91f17d1cbd37f..239426021613a7aa842bf46876642d05f9236f8b 100644 --- a/components/net/sf/briar/transport/PacketDecrypter.java +++ b/components/net/sf/briar/transport/PacketDecrypter.java @@ -8,6 +8,10 @@ interface PacketDecrypter { /** Returns the input stream from which packets should be read. */ InputStream getInputStream(); - /** Reads, decrypts and returns a tag from the underlying input stream. */ + /** + * Reads, decrypts and returns a tag from the underlying input stream. + * Returns null if the end of the input stream is reached before any bytes + * are read. + */ byte[] readTag() throws IOException; } diff --git a/components/net/sf/briar/transport/PacketDecrypterImpl.java b/components/net/sf/briar/transport/PacketDecrypterImpl.java index c0645af46221dbbbbee519d72e979473c0f0f29e..3fe5adb5c6840f92bf52db92aa53017b34f97463 100644 --- a/components/net/sf/briar/transport/PacketDecrypterImpl.java +++ b/components/net/sf/briar/transport/PacketDecrypterImpl.java @@ -54,9 +54,11 @@ class PacketDecrypterImpl extends FilterInputStream implements PacketDecrypter { bufOff = bufLen = 0; while(offset < tag.length) { int read = in.read(tag, offset, tag.length - offset); - if(read == -1) throw new EOFException(); + if(read == -1) break; offset += read; } + if(offset == 0) return null; // EOF between packets is acceptable + if(offset < tag.length) throw new EOFException(); betweenPackets = false; try { byte[] decryptedTag = tagCipher.doFinal(tag); diff --git a/components/net/sf/briar/transport/PacketReaderFactoryImpl.java b/components/net/sf/briar/transport/PacketReaderFactoryImpl.java index 119daa9cddc36fc7c8eddfd89816dd6317bcb0f1..5e5a7f2bff4bf5bbd3ae273d4a63e7ff4d6f2645 100644 --- a/components/net/sf/briar/transport/PacketReaderFactoryImpl.java +++ b/components/net/sf/briar/transport/PacketReaderFactoryImpl.java @@ -8,47 +8,18 @@ import javax.crypto.Mac; import javax.crypto.SecretKey; import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; -import net.sf.briar.api.protocol.Offer; -import net.sf.briar.api.protocol.Request; -import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.serial.ObjectReader; -import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.transport.PacketReader; import net.sf.briar.api.transport.PacketReaderFactory; import com.google.inject.Inject; -import com.google.inject.Provider; class PacketReaderFactoryImpl implements PacketReaderFactory { private final CryptoComponent crypto; - private final ReaderFactory readerFactory; - private final Provider<ObjectReader<Ack>> ackProvider; - private final Provider<ObjectReader<Batch>> batchProvider; - private final Provider<ObjectReader<Offer>> offerProvider; - private final Provider<ObjectReader<Request>> requestProvider; - private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider; - private final Provider<ObjectReader<TransportUpdate>> transportProvider; @Inject - PacketReaderFactoryImpl(CryptoComponent crypto, ReaderFactory readerFactory, - Provider<ObjectReader<Ack>> ackProvider, - Provider<ObjectReader<Batch>> batchProvider, - Provider<ObjectReader<Offer>> offerProvider, - Provider<ObjectReader<Request>> requestProvider, - Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider, - Provider<ObjectReader<TransportUpdate>> transportProvider) { + PacketReaderFactoryImpl(CryptoComponent crypto) { this.crypto = crypto; - this.readerFactory = readerFactory; - this.ackProvider = ackProvider; - this.batchProvider = batchProvider; - this.offerProvider = offerProvider; - this.requestProvider = requestProvider; - this.subscriptionProvider = subscriptionProvider; - this.transportProvider = transportProvider; } public PacketReader createPacketReader(byte[] firstTag, InputStream in, @@ -66,9 +37,6 @@ class PacketReaderFactoryImpl implements PacketReaderFactory { } PacketDecrypter decrypter = new PacketDecrypterImpl(firstTag, in, tagCipher, packetCipher, tagKey, packetKey); - return new PacketReaderImpl(firstTag, readerFactory, ackProvider.get(), - batchProvider.get(), offerProvider.get(), requestProvider.get(), - subscriptionProvider.get(), transportProvider.get(), - decrypter, mac, transportId, connection); + return new PacketReaderImpl(decrypter, mac, transportId, connection); } } diff --git a/components/net/sf/briar/transport/PacketReaderImpl.java b/components/net/sf/briar/transport/PacketReaderImpl.java index 363cc0a41a3b4b9e6dc1c976374e8d206832abfe..0e6372a0c6d8c4577629ed73bd11666e3c6c8f43 100644 --- a/components/net/sf/briar/transport/PacketReaderImpl.java +++ b/components/net/sf/briar/transport/PacketReaderImpl.java @@ -1,27 +1,18 @@ package net.sf.briar.transport; +import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; +import java.security.GeneralSecurityException; import java.util.Arrays; import javax.crypto.Mac; -import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; -import net.sf.briar.api.protocol.Offer; -import net.sf.briar.api.protocol.Request; -import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.Tags; -import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.serial.FormatException; -import net.sf.briar.api.serial.ObjectReader; -import net.sf.briar.api.serial.Reader; -import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.transport.PacketReader; -class PacketReaderImpl implements PacketReader { +class PacketReaderImpl extends FilterInputStream implements PacketReader { - private final Reader reader; private final PacketDecrypter decrypter; private final Mac mac; private final int macLength, transportId; @@ -30,23 +21,9 @@ class PacketReaderImpl implements PacketReader { private long packet = 0L; private boolean betweenPackets = true; - PacketReaderImpl(byte[] firstTag, ReaderFactory readerFactory, - ObjectReader<Ack> ackReader, ObjectReader<Batch> batchReader, - ObjectReader<Offer> offerReader, - ObjectReader<Request> requestReader, - ObjectReader<SubscriptionUpdate> subscriptionReader, - ObjectReader<TransportUpdate> transportReader, - PacketDecrypter decrypter, Mac mac, int transportId, + PacketReaderImpl(PacketDecrypter decrypter, Mac mac, int transportId, long connection) { - InputStream in = decrypter.getInputStream(); - reader = readerFactory.createReader(in); - reader.addObjectReader(Tags.ACK, ackReader); - reader.addObjectReader(Tags.BATCH, batchReader); - reader.addObjectReader(Tags.OFFER, offerReader); - reader.addObjectReader(Tags.REQUEST, requestReader); - reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader); - reader.addObjectReader(Tags.TRANSPORTS, transportReader); - reader.addConsumer(new MacConsumer(mac)); + super(decrypter.getInputStream()); this.decrypter = decrypter; this.mac = mac; macLength = mac.getMacLength(); @@ -54,32 +31,36 @@ class PacketReaderImpl implements PacketReader { this.connection = connection; } - public boolean hasAck() throws IOException { + public InputStream getInputStream() { + return this; + } + + public void finishPacket() throws IOException, GeneralSecurityException { + if(!betweenPackets) readMac(); + } + + @Override + public int read() throws IOException { if(betweenPackets) readTag(); - return reader.hasUserDefined(Tags.ACK); + int i = in.read(); + if(i != -1) mac.update((byte) i); + return i; } - private void readTag() throws IOException { - assert betweenPackets; - if(packet > Constants.MAX_32_BIT_UNSIGNED) - throw new IllegalStateException(); - byte[] tag = decrypter.readTag(); - if(!TagDecoder.decodeTag(tag, transportId, connection, packet)) - throw new FormatException(); - mac.update(tag); - packet++; - betweenPackets = false; + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); } - public Ack readAck() throws IOException { + @Override + public int read(byte[] b, int off, int len) throws IOException { if(betweenPackets) readTag(); - Ack a = reader.readUserDefined(Tags.ACK, Ack.class); - readMac(); - betweenPackets = true; - return a; + int i = in.read(b, off, len); + if(i != -1) mac.update(b, off, i); + return i; } - private void readMac() throws IOException { + private void readMac() throws IOException, GeneralSecurityException { byte[] expectedMac = mac.doFinal(); byte[] actualMac = new byte[macLength]; InputStream in = decrypter.getInputStream(); @@ -89,74 +70,22 @@ class PacketReaderImpl implements PacketReader { if(read == -1) break; offset += read; } - if(offset < macLength) throw new FormatException(); - if(!Arrays.equals(expectedMac, actualMac)) throw new FormatException(); - } - - public boolean hasBatch() throws IOException { - if(betweenPackets) readTag(); - return reader.hasUserDefined(Tags.BATCH); - } - - public Batch readBatch() throws IOException { - if(betweenPackets) readTag(); - Batch b = reader.readUserDefined(Tags.BATCH, Batch.class); - readMac(); - betweenPackets = true; - return b; - } - - public boolean hasOffer() throws IOException { - if(betweenPackets) readTag(); - return reader.hasUserDefined(Tags.OFFER); - } - - public Offer readOffer() throws IOException { - if(betweenPackets) readTag(); - Offer o = reader.readUserDefined(Tags.OFFER, Offer.class); - readMac(); + if(offset < macLength) throw new GeneralSecurityException(); + if(!Arrays.equals(expectedMac, actualMac)) + throw new GeneralSecurityException(); betweenPackets = true; - return o; } - public boolean hasRequest() throws IOException { - if(betweenPackets) readTag(); - return reader.hasUserDefined(Tags.REQUEST); - } - - public Request readRequest() throws IOException { - if(betweenPackets) readTag(); - Request r = reader.readUserDefined(Tags.REQUEST, Request.class); - readMac(); - betweenPackets = true; - return r; - } - - public boolean hasSubscriptionUpdate() throws IOException { - if(betweenPackets) readTag(); - return reader.hasUserDefined(Tags.SUBSCRIPTIONS); - } - - public SubscriptionUpdate readSubscriptionUpdate() throws IOException { - if(betweenPackets) readTag(); - SubscriptionUpdate s = reader.readUserDefined(Tags.SUBSCRIPTIONS, - SubscriptionUpdate.class); - readMac(); - betweenPackets = true; - return s; - } - - public boolean hasTransportUpdate() throws IOException { - if(betweenPackets) readTag(); - return reader.hasUserDefined(Tags.TRANSPORTS); - } - - public TransportUpdate readTransportUpdate() throws IOException { - if(betweenPackets) readTag(); - TransportUpdate t = reader.readUserDefined(Tags.TRANSPORTS, - TransportUpdate.class); - readMac(); - betweenPackets = true; - return t; + private void readTag() throws IOException { + assert betweenPackets; + if(packet > Constants.MAX_32_BIT_UNSIGNED) + throw new IllegalStateException(); + byte[] tag = decrypter.readTag(); + if(tag == null) return; // EOF + if(!TagDecoder.decodeTag(tag, transportId, connection, packet)) + throw new FormatException(); + mac.update(tag); + packet++; + betweenPackets = false; } } diff --git a/components/net/sf/briar/transport/PacketWriterImpl.java b/components/net/sf/briar/transport/PacketWriterImpl.java index dd78de97f656ed582a85974fbf5323879e51274f..943e3d95f6e4a17dd4c929713596b0b71e05fdf4 100644 --- a/components/net/sf/briar/transport/PacketWriterImpl.java +++ b/components/net/sf/briar/transport/PacketWriterImpl.java @@ -37,7 +37,7 @@ class PacketWriterImpl extends FilterOutputStream implements PacketWriter { return this; } - public void nextPacket() throws IOException { + public void finishPacket() throws IOException { if(!betweenPackets) writeMac(); } @@ -50,9 +50,7 @@ class PacketWriterImpl extends FilterOutputStream implements PacketWriter { @Override public void write(byte[] b) throws IOException { - if(betweenPackets) writeTag(); - out.write(b); - mac.update(b); + write(b, 0, b.length); } @Override diff --git a/test/build.xml b/test/build.xml index e3eaf8b3cdb7d1602eca0e098c7d2466281f8fbe..1036f8f05f5de6be99ef9455073b9480bea1fc80 100644 --- a/test/build.xml +++ b/test/build.xml @@ -36,6 +36,8 @@ <test name='net.sf.briar.transport.ConnectionWindowImplTest'/> <test name='net.sf.briar.transport.PacketDecrypterImplTest'/> <test name='net.sf.briar.transport.PacketEncrypterImplTest'/> + <test name='net.sf.briar.transport.PacketReaderImplTest'/> + <test name='net.sf.briar.transport.PacketReadWriteTest'/> <test name='net.sf.briar.transport.PacketWriterImplTest'/> <test name='net.sf.briar.transport.TagDecoderTest'/> <test name='net.sf.briar.transport.TagEncoderTest'/> diff --git a/test/net/sf/briar/FileReadWriteTest.java b/test/net/sf/briar/FileReadWriteTest.java index f04f8375ff2631d709ff4eaf6bfc3427b6cc5fdd..490c0d08b0c93680d030254b80c2186324725708 100644 --- a/test/net/sf/briar/FileReadWriteTest.java +++ b/test/net/sf/briar/FileReadWriteTest.java @@ -3,6 +3,7 @@ package net.sf.briar; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; +import java.io.InputStream; import java.io.OutputStream; import java.security.KeyPair; import java.util.Arrays; @@ -26,6 +27,8 @@ import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.ProtocolReader; +import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; @@ -64,6 +67,7 @@ public class FileReadWriteTest extends TestCase { private final PacketReaderFactory packetReaderFactory; private final PacketWriterFactory packetWriterFactory; + private final ProtocolReaderFactory protocolReaderFactory; private final ProtocolWriterFactory protocolWriterFactory; private final CryptoComponent crypto; private final byte[] secret = new byte[45]; @@ -83,6 +87,7 @@ public class FileReadWriteTest extends TestCase { new WritersModule()); packetReaderFactory = i.getInstance(PacketReaderFactory.class); packetWriterFactory = i.getInstance(PacketWriterFactory.class); + protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); crypto = i.getInstance(CryptoComponent.class); assertEquals(crypto.getMessageDigest().getDigestLength(), @@ -121,14 +126,14 @@ public class FileReadWriteTest extends TestCase { @Test public void testWriteFile() throws Exception { OutputStream out = new FileOutputStream(file); - PacketWriter p = packetWriterFactory.createPacketWriter(out, + PacketWriter packetWriter = packetWriterFactory.createPacketWriter(out, transportId, connection, secret); - out = p.getOutputStream(); + out = packetWriter.getOutputStream(); AckWriter a = protocolWriterFactory.createAckWriter(out); assertTrue(a.writeBatchId(ack)); a.finish(); - p.nextPacket(); + packetWriter.finishPacket(); BatchWriter b = protocolWriterFactory.createBatchWriter(out); assertTrue(b.writeMessage(message.getBytes())); @@ -136,7 +141,7 @@ public class FileReadWriteTest extends TestCase { assertTrue(b.writeMessage(message2.getBytes())); assertTrue(b.writeMessage(message3.getBytes())); b.finish(); - p.nextPacket(); + packetWriter.finishPacket(); OfferWriter o = protocolWriterFactory.createOfferWriter(out); assertTrue(o.writeMessageId(message.getId())); @@ -144,14 +149,14 @@ public class FileReadWriteTest extends TestCase { assertTrue(o.writeMessageId(message2.getId())); assertTrue(o.writeMessageId(message3.getId())); o.finish(); - p.nextPacket(); + packetWriter.finishPacket(); RequestWriter r = protocolWriterFactory.createRequestWriter(out); BitSet requested = new BitSet(4); requested.set(1); requested.set(3); r.writeBitmap(requested, 4); - p.nextPacket(); + packetWriter.finishPacket(); SubscriptionWriter s = protocolWriterFactory.createSubscriptionWriter(out); @@ -160,11 +165,11 @@ public class FileReadWriteTest extends TestCase { subs.put(group, 0L); subs.put(group1, 0L); s.writeSubscriptions(subs); - p.nextPacket(); + packetWriter.finishPacket(); TransportWriter t = protocolWriterFactory.createTransportWriter(out); t.writeTransports(transports); - p.nextPacket(); + packetWriter.finishPacket(); out.flush(); out.close(); @@ -177,7 +182,7 @@ public class FileReadWriteTest extends TestCase { testWriteFile(); - FileInputStream in = new FileInputStream(file); + InputStream in = new FileInputStream(file); byte[] firstTag = new byte[16]; int offset = 0; while(offset < 16) { @@ -186,17 +191,22 @@ public class FileReadWriteTest extends TestCase { offset += read; } assertEquals(16, offset); - PacketReader p = packetReaderFactory.createPacketReader(firstTag, in, - transportId, connection, secret); + PacketReader packetReader = packetReaderFactory.createPacketReader( + firstTag, in, transportId, connection, secret); + in = packetReader.getInputStream(); + ProtocolReader protocolReader = + protocolReaderFactory.createProtocolReader(in); // Read the ack - assertTrue(p.hasAck()); - Ack a = p.readAck(); + assertTrue(protocolReader.hasAck()); + Ack a = protocolReader.readAck(); + packetReader.finishPacket(); assertEquals(Collections.singletonList(ack), a.getBatchIds()); // Read the batch - assertTrue(p.hasBatch()); - Batch b = p.readBatch(); + assertTrue(protocolReader.hasBatch()); + Batch b = protocolReader.readBatch(); + packetReader.finishPacket(); Collection<Message> messages = b.getMessages(); assertEquals(4, messages.size()); Iterator<Message> it = messages.iterator(); @@ -206,8 +216,9 @@ public class FileReadWriteTest extends TestCase { checkMessageEquality(message3, it.next()); // Read the offer - assertTrue(p.hasOffer()); - Offer o = p.readOffer(); + assertTrue(protocolReader.hasOffer()); + Offer o = protocolReader.readOffer(); + packetReader.finishPacket(); Collection<MessageId> offered = o.getMessageIds(); assertEquals(4, offered.size()); Iterator<MessageId> it1 = offered.iterator(); @@ -217,8 +228,9 @@ public class FileReadWriteTest extends TestCase { assertEquals(message3.getId(), it1.next()); // Read the request - assertTrue(p.hasRequest()); - Request r = p.readRequest(); + assertTrue(protocolReader.hasRequest()); + Request r = protocolReader.readRequest(); + packetReader.finishPacket(); BitSet requested = r.getBitmap(); assertFalse(requested.get(0)); assertTrue(requested.get(1)); @@ -228,8 +240,9 @@ public class FileReadWriteTest extends TestCase { assertEquals(2, requested.cardinality()); // Read the subscription update - assertTrue(p.hasSubscriptionUpdate()); - SubscriptionUpdate s = p.readSubscriptionUpdate(); + assertTrue(protocolReader.hasSubscriptionUpdate()); + SubscriptionUpdate s = protocolReader.readSubscriptionUpdate(); + packetReader.finishPacket(); Map<Group, Long> subs = s.getSubscriptions(); assertEquals(2, subs.size()); assertEquals(Long.valueOf(0L), subs.get(group)); @@ -238,11 +251,14 @@ public class FileReadWriteTest extends TestCase { assertTrue(s.getTimestamp() <= System.currentTimeMillis()); // Read the transport update - assertTrue(p.hasTransportUpdate()); - TransportUpdate t = p.readTransportUpdate(); + assertTrue(protocolReader.hasTransportUpdate()); + TransportUpdate t = protocolReader.readTransportUpdate(); + packetReader.finishPacket(); assertEquals(transports, t.getTransports()); assertTrue(t.getTimestamp() > start); assertTrue(t.getTimestamp() <= System.currentTimeMillis()); + + in.close(); } @After diff --git a/test/net/sf/briar/transport/PacketReadWriteTest.java b/test/net/sf/briar/transport/PacketReadWriteTest.java new file mode 100644 index 0000000000000000000000000000000000000000..a4f47e1bfa1a5df8ac645d00d26f1de1c6f9827d --- /dev/null +++ b/test/net/sf/briar/transport/PacketReadWriteTest.java @@ -0,0 +1,98 @@ +package net.sf.briar.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Random; + +import javax.crypto.Cipher; +import javax.crypto.Mac; +import javax.crypto.SecretKey; + +import junit.framework.TestCase; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.transport.PacketReader; +import net.sf.briar.api.transport.PacketWriter; +import net.sf.briar.crypto.CryptoModule; + +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class PacketReadWriteTest extends TestCase { + + private final CryptoComponent crypto; + private final Cipher tagCipher, packetCipher; + private final SecretKey macKey, tagKey, packetKey; + private final Mac mac; + private final Random random; + private final byte[] secret = new byte[100]; + private final int transportId = 999; + private final long connection = 1234L; + + public PacketReadWriteTest() { + super(); + Injector i = Guice.createInjector(new CryptoModule()); + crypto = i.getInstance(CryptoComponent.class); + tagCipher = crypto.getTagCipher(); + packetCipher = crypto.getPacketCipher(); + macKey = crypto.deriveMacKey(secret); + tagKey = crypto.deriveTagKey(secret); + packetKey = crypto.derivePacketKey(secret); + mac = crypto.getMac(); + random = new Random(); + } + + @Test + public void testWriteAndRead() throws Exception { + // Generate two random packets + byte[] packet = new byte[12345]; + random.nextBytes(packet); + byte[] packet1 = new byte[321]; + random.nextBytes(packet1); + // Write the packets + ByteArrayOutputStream out = new ByteArrayOutputStream(); + PacketEncrypter encrypter = new PacketEncrypterImpl(out, tagCipher, + packetCipher, tagKey, packetKey); + mac.init(macKey); + PacketWriter writer = new PacketWriterImpl(encrypter, mac, transportId, + connection); + OutputStream out1 = writer.getOutputStream(); + out1.write(packet); + writer.finishPacket(); + out1.write(packet1); + writer.finishPacket(); + // Read the packets back + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + byte[] firstTag = new byte[Constants.TAG_BYTES]; + assertEquals(Constants.TAG_BYTES, in.read(firstTag)); + PacketDecrypter decrypter = new PacketDecrypterImpl(firstTag, in, + tagCipher, packetCipher, tagKey, packetKey); + PacketReader reader = new PacketReaderImpl(decrypter, mac, transportId, + connection); + InputStream in1 = reader.getInputStream(); + byte[] recovered = new byte[packet.length]; + int offset = 0; + while(offset < recovered.length) { + int read = in1.read(recovered, offset, recovered.length - offset); + if(read == -1) break; + offset += read; + } + assertEquals(recovered.length, offset); + reader.finishPacket(); + assertTrue(Arrays.equals(packet, recovered)); + byte[] recovered1 = new byte[packet1.length]; + offset = 0; + while(offset < recovered1.length) { + int read = in1.read(recovered1, offset, recovered1.length - offset); + if(read == -1) break; + offset += read; + } + assertEquals(recovered1.length, offset); + reader.finishPacket(); + assertTrue(Arrays.equals(packet1, recovered1)); + } +} diff --git a/test/net/sf/briar/transport/PacketReaderImplTest.java b/test/net/sf/briar/transport/PacketReaderImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..f2124401b3c4444aa168febf541d3ee8ab433c02 --- /dev/null +++ b/test/net/sf/briar/transport/PacketReaderImplTest.java @@ -0,0 +1,187 @@ +package net.sf.briar.transport; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.security.GeneralSecurityException; +import java.util.Arrays; + +import javax.crypto.Mac; + +import junit.framework.TestCase; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.transport.PacketReader; +import net.sf.briar.crypto.CryptoModule; +import net.sf.briar.util.StringUtils; + +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class PacketReaderImplTest extends TestCase { + + private final Mac mac; + + public PacketReaderImplTest() throws Exception { + super(); + Injector i = Guice.createInjector(new CryptoModule()); + CryptoComponent crypto = i.getInstance(CryptoComponent.class); + mac = crypto.getMac(); + mac.init(crypto.generateSecretKey()); + } + + @Test + public void testFirstReadTriggersTag() throws Exception { + // TAG_BYTES for the tag, 1 byte for the packet + byte[] b = new byte[Constants.TAG_BYTES + 1]; + ByteArrayInputStream in = new ByteArrayInputStream(b); + PacketDecrypter d = new NullPacketDecrypter(in); + PacketReader p = new PacketReaderImpl(d, mac, 0, 0L); + // There should be one byte available before EOF + assertEquals(0, p.getInputStream().read()); + assertEquals(-1, p.getInputStream().read()); + } + + @Test + public void testFinishPacketAfterReadTriggersMac() throws Exception { + // TAG_BYTES for the tag, 1 byte for the packet + byte[] b = new byte[Constants.TAG_BYTES + 1]; + // Calculate the MAC and append it to the packet + mac.update(b); + byte[] macBytes = mac.doFinal(); + byte[] b1 = Arrays.copyOf(b, b.length + macBytes.length); + System.arraycopy(macBytes, 0, b1, b.length, macBytes.length); + // Check that the PacketReader reads and verifies the MAC + ByteArrayInputStream in = new ByteArrayInputStream(b1); + PacketDecrypter d = new NullPacketDecrypter(in); + PacketReader p = new PacketReaderImpl(d, mac, 0, 0L); + assertEquals(0, p.getInputStream().read()); + p.finishPacket(); + // Reading the MAC should take us to EOF + assertEquals(-1, p.getInputStream().read()); + } + + @Test + public void testModifyingPacketInvalidatesMac() throws Exception { + // TAG_BYTES for the tag, 1 byte for the packet + byte[] b = new byte[Constants.TAG_BYTES + 1]; + // Calculate the MAC and append it to the packet + mac.update(b); + byte[] macBytes = mac.doFinal(); + byte[] b1 = Arrays.copyOf(b, b.length + macBytes.length); + System.arraycopy(macBytes, 0, b1, b.length, macBytes.length); + // Modify the packet + b1[Constants.TAG_BYTES] = (byte) 1; + // Check that the PacketReader reads and fails to verify the MAC + ByteArrayInputStream in = new ByteArrayInputStream(b1); + PacketDecrypter d = new NullPacketDecrypter(in); + PacketReader p = new PacketReaderImpl(d, mac, 0, 0L); + assertEquals(1, p.getInputStream().read()); + try { + p.finishPacket(); + fail(); + } catch(GeneralSecurityException expected) {} + } + + @Test + public void testExtraCallsToFinishPacketDoNothing() throws Exception { + // TAG_BYTES for the tag, 1 byte for the packet + byte[] b = new byte[Constants.TAG_BYTES + 1]; + // Calculate the MAC and append it to the packet + mac.update(b); + byte[] macBytes = mac.doFinal(); + byte[] b1 = Arrays.copyOf(b, b.length + macBytes.length); + System.arraycopy(macBytes, 0, b1, b.length, macBytes.length); + // Check that the PacketReader reads and verifies the MAC + ByteArrayInputStream in = new ByteArrayInputStream(b1); + PacketDecrypter d = new NullPacketDecrypter(in); + PacketReader p = new PacketReaderImpl(d, mac, 0, 0L); + // Initial calls to finishPacket() should have no effect + p.finishPacket(); + p.finishPacket(); + p.finishPacket(); + assertEquals(0, p.getInputStream().read()); + p.finishPacket(); + // Extra calls to finishPacket() should have no effect + p.finishPacket(); + p.finishPacket(); + p.finishPacket(); + // Reading the MAC should take us to EOF + assertEquals(-1, p.getInputStream().read()); + } + + @Test + public void testPacketNumberIsIncremented() throws Exception { + byte[] tag = StringUtils.fromHexString( + "0000" // 16 bits reserved + + "F00D" // 16 bits for the transport ID + + "DEADBEEF" // 32 bits for the connection number + + "00000000" // 32 bits for the packet number + + "00000000" // 32 bits for the block number + ); + assertEquals(Constants.TAG_BYTES, tag.length); + byte[] tag1 = StringUtils.fromHexString( + "0000" // 16 bits reserved + + "F00D" // 16 bits for the transport ID + + "DEADBEEF" // 32 bits for the connection number + + "00000001" // 32 bits for the packet number + + "00000000" // 32 bits for the block number + ); + assertEquals(Constants.TAG_BYTES, tag1.length); + // Calculate the MAC on the first packet and append it to the packet + mac.update(tag); + mac.update((byte) 0); + byte[] macBytes = mac.doFinal(); + byte[] b = Arrays.copyOf(tag, tag.length + 1 + macBytes.length); + System.arraycopy(macBytes, 0, b, tag.length + 1, macBytes.length); + // Calculate the MAC on the second packet and append it to the packet + mac.update(tag1); + mac.update((byte) 0); + byte[] macBytes1 = mac.doFinal(); + byte[] b1 = Arrays.copyOf(tag1, tag1.length + 1 + macBytes1.length); + System.arraycopy(macBytes1, 0, b1, tag.length + 1, macBytes1.length); + // Check that the PacketReader accepts the correct tags and MACs + byte[] b2 = Arrays.copyOf(b, b.length + b1.length); + System.arraycopy(b1, 0, b2, b.length, b1.length); + ByteArrayInputStream in = new ByteArrayInputStream(b2); + PacketDecrypter d = new NullPacketDecrypter(in); + PacketReader p = new PacketReaderImpl(d, mac, 0xF00D, 0xDEADBEEFL); + // Packet one + assertEquals(0, p.getInputStream().read()); + p.finishPacket(); + // Packet two + assertEquals(0, p.getInputStream().read()); + p.finishPacket(); + // We should be at EOF + assertEquals(-1, p.getInputStream().read()); + } + + /** A PacketDecrypter that performs no decryption. */ + private static class NullPacketDecrypter implements PacketDecrypter { + + private final InputStream in; + + private NullPacketDecrypter(InputStream in) { + this.in = in; + } + + public InputStream getInputStream() { + return in; + } + + public byte[] readTag() throws IOException { + byte[] tag = new byte[Constants.TAG_BYTES]; + int offset = 0; + while(offset < tag.length) { + int read = in.read(tag, offset, tag.length - offset); + if(read == -1) break; + offset += read; + } + if(offset == 0) return null; // EOF between packets is acceptable + if(offset < tag.length) throw new EOFException(); + return tag; + } + } +} diff --git a/test/net/sf/briar/transport/PacketWriterImplTest.java b/test/net/sf/briar/transport/PacketWriterImplTest.java index c79234b3647a2a6a7d732c5fc3522a175619bbd2..aae1a074497e366d91756ee26f7567329f8f5858 100644 --- a/test/net/sf/briar/transport/PacketWriterImplTest.java +++ b/test/net/sf/briar/transport/PacketWriterImplTest.java @@ -36,52 +36,56 @@ public class PacketWriterImplTest extends TestCase { PacketEncrypter e = new NullPacketEncrypter(out); PacketWriter p = new PacketWriterImpl(e, mac, 0, 0L); p.getOutputStream().write(0); - // There should be TAG_BYTES bytes for the tag, 1 byte for the write + // There should be TAG_BYTES bytes for the tag, 1 byte for the packet assertTrue(Arrays.equals(new byte[Constants.TAG_BYTES + 1], out.toByteArray())); } @Test - public void testNextPacketAfterWriteTriggersMac() throws Exception { + public void testFinishPacketAfterWriteTriggersMac() throws Exception { // Calculate what the MAC should be - mac.update(new byte[17]); + mac.update(new byte[Constants.TAG_BYTES + 1]); byte[] expectedMac = mac.doFinal(); // Check that the PacketWriter calculates and writes the correct MAC ByteArrayOutputStream out = new ByteArrayOutputStream(); PacketEncrypter e = new NullPacketEncrypter(out); PacketWriter p = new PacketWriterImpl(e, mac, 0, 0L); p.getOutputStream().write(0); - p.nextPacket(); + p.finishPacket(); byte[] written = out.toByteArray(); - assertEquals(17 + expectedMac.length, written.length); + assertEquals(Constants.TAG_BYTES + 1 + expectedMac.length, + written.length); byte[] actualMac = new byte[expectedMac.length]; - System.arraycopy(written, 17, actualMac, 0, actualMac.length); + System.arraycopy(written, Constants.TAG_BYTES + 1, actualMac, 0, + actualMac.length); assertTrue(Arrays.equals(expectedMac, actualMac)); } @Test - public void testExtraCallsToNextPacketDoNothing() throws Exception { + public void testExtraCallsToFinishPacketDoNothing() throws Exception { // Calculate what the MAC should be - mac.update(new byte[17]); + mac.update(new byte[Constants.TAG_BYTES + 1]); byte[] expectedMac = mac.doFinal(); // Check that the PacketWriter calculates and writes the correct MAC ByteArrayOutputStream out = new ByteArrayOutputStream(); PacketEncrypter e = new NullPacketEncrypter(out); PacketWriter p = new PacketWriterImpl(e, mac, 0, 0L); - // Initial calls to nextPacket() should have no effect - p.nextPacket(); - p.nextPacket(); - p.nextPacket(); + // Initial calls to finishPacket() should have no effect + p.finishPacket(); + p.finishPacket(); + p.finishPacket(); p.getOutputStream().write(0); - p.nextPacket(); - // Extra calls to nextPacket() should have no effect - p.nextPacket(); - p.nextPacket(); - p.nextPacket(); + p.finishPacket(); + // Extra calls to finishPacket() should have no effect + p.finishPacket(); + p.finishPacket(); + p.finishPacket(); byte[] written = out.toByteArray(); - assertEquals(17 + expectedMac.length, written.length); + assertEquals(Constants.TAG_BYTES + 1 + expectedMac.length, + written.length); byte[] actualMac = new byte[expectedMac.length]; - System.arraycopy(written, 17, actualMac, 0, actualMac.length); + System.arraycopy(written, Constants.TAG_BYTES + 1, actualMac, 0, + actualMac.length); assertTrue(Arrays.equals(expectedMac, actualMac)); } @@ -117,10 +121,10 @@ public class PacketWriterImplTest extends TestCase { PacketWriter p = new PacketWriterImpl(e, mac, 0xF00D, 0xDEADBEEFL); // Packet one p.getOutputStream().write(0); - p.nextPacket(); + p.finishPacket(); // Packet two p.getOutputStream().write(0); - p.nextPacket(); + p.finishPacket(); byte[] written = out.toByteArray(); assertEquals(Constants.TAG_BYTES + 1 + expectedMac.length + Constants.TAG_BYTES + 1 + expectedMac1.length, @@ -146,6 +150,7 @@ public class PacketWriterImplTest extends TestCase { assertTrue(Arrays.equals(expectedMac1, actualMac1)); } + /** A PacketEncrypter that performs no encryption. */ private static class NullPacketEncrypter implements PacketEncrypter { private final OutputStream out;