diff --git a/api/net/sf/briar/api/protocol/writers/ProtocolReaderFactory.java b/api/net/sf/briar/api/protocol/writers/ProtocolReaderFactory.java deleted file mode 100644 index 5a0609c180a709f46386658dd771c718cb574ac1..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/ProtocolReaderFactory.java +++ /dev/null @@ -1,26 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.InputStream; - -import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; -import net.sf.briar.api.protocol.Offer; -import net.sf.briar.api.protocol.Request; -import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.serial.ObjectReader; - -public interface ProtocolReaderFactory { - - ObjectReader<Ack> createAckReader(InputStream in); - - ObjectReader<Batch> createBatchReader(InputStream in); - - ObjectReader<Offer> createOfferReader(InputStream in); - - ObjectReader<Request> createRequestReader(InputStream in); - - ObjectReader<SubscriptionUpdate> createSubscriptionReader(InputStream in); - - ObjectReader<TransportUpdate> createTransportReader(InputStream in); -} diff --git a/api/net/sf/briar/api/transport/PacketReader.java b/api/net/sf/briar/api/transport/PacketReader.java index 0dde50d3279ed07afc358e9438c94adf25c9a4e4..95001e2116cccae6f8c36a34afa6dbab5fc1f91e 100644 --- a/api/net/sf/briar/api/transport/PacketReader.java +++ b/api/net/sf/briar/api/transport/PacketReader.java @@ -15,8 +15,6 @@ import net.sf.briar.api.protocol.TransportUpdate; */ public interface PacketReader { - boolean eof() throws IOException; - boolean hasAck() throws IOException; Ack readAck() throws IOException; diff --git a/api/net/sf/briar/api/transport/PacketReaderFactory.java b/api/net/sf/briar/api/transport/PacketReaderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..8e2ec5b3b651c03ac30de1c1c1ae2cbd3216de91 --- /dev/null +++ b/api/net/sf/briar/api/transport/PacketReaderFactory.java @@ -0,0 +1,9 @@ +package net.sf.briar.api.transport; + +import java.io.InputStream; + +public interface PacketReaderFactory { + + PacketReader createPacketReader(byte[] firstTag, InputStream in, + int transportId, long connection, byte[] secret); +} diff --git a/api/net/sf/briar/api/transport/PacketWriterFactory.java b/api/net/sf/briar/api/transport/PacketWriterFactory.java index f820576b0428826d4568885923a35c1eaf0092a8..c93e41e7567068d258d4a7b7ea0e97d6dafeea1c 100644 --- a/api/net/sf/briar/api/transport/PacketWriterFactory.java +++ b/api/net/sf/briar/api/transport/PacketWriterFactory.java @@ -2,11 +2,8 @@ package net.sf.briar.api.transport; import java.io.OutputStream; -import javax.crypto.SecretKey; - public interface PacketWriterFactory { PacketWriter createPacketWriter(OutputStream out, int transportId, - long connection, SecretKey macKey, SecretKey tagKey, - SecretKey packetKey); + long connection, byte[] secret); } diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index 57e6e2c6953721b64698e6c124eb0116cee76bd2..dba8dbd0c351038eb270f1d288fd0bc136b764a8 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -1,14 +1,20 @@ package net.sf.briar.protocol; import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.AuthorFactory; +import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.SubscriptionUpdate; +import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.serial.ObjectReader; import com.google.inject.AbstractModule; @@ -29,12 +35,24 @@ public class ProtocolModule extends AbstractModule { bind(MessageEncoder.class).to(MessageEncoderImpl.class); } + @Provides + ObjectReader<Ack> getAckReader(ObjectReader<BatchId> batchIdReader, + AckFactory ackFactory) { + return new AckReader(batchIdReader, ackFactory); + } + @Provides ObjectReader<Author> getAuthorReader(CryptoComponent crypto, AuthorFactory authorFactory) { return new AuthorReader(crypto, authorFactory); } + @Provides + ObjectReader<Batch> getBatchReader(CryptoComponent crypto, + ObjectReader<Message> messageReader, BatchFactory batchFactory) { + return new BatchReader(crypto, messageReader, batchFactory); + } + @Provides ObjectReader<BatchId> getBatchIdReader() { return new BatchIdReader(); @@ -59,4 +77,28 @@ public class ProtocolModule extends AbstractModule { return new MessageReader(crypto, messageIdReader, groupReader, authorReader); } + + @Provides + ObjectReader<Offer> getOfferReader(ObjectReader<MessageId> messageIdReader, + OfferFactory offerFactory) { + return new OfferReader(messageIdReader, offerFactory); + } + + @Provides + ObjectReader<Request> getRequestReader(RequestFactory requestFactory) { + return new RequestReader(requestFactory); + } + + @Provides + ObjectReader<SubscriptionUpdate> getSubscriptionReader( + ObjectReader<Group> groupReader, + SubscriptionFactory subscriptionFactory) { + return new SubscriptionReader(groupReader, subscriptionFactory); + } + + @Provides + ObjectReader<TransportUpdate> getTransportReader( + TransportFactory transportFactory) { + return new TransportReader(transportFactory); + } } diff --git a/components/net/sf/briar/transport/PacketEncrypterImpl.java b/components/net/sf/briar/transport/PacketEncrypterImpl.java index 4469a0f4e0ab53fdca391038b1154a603463b2d2..69ca0c9ca3aa60d608f4caaa4ce1d0af22c3ab16 100644 --- a/components/net/sf/briar/transport/PacketEncrypterImpl.java +++ b/components/net/sf/briar/transport/PacketEncrypterImpl.java @@ -10,7 +10,6 @@ import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.SecretKey; -import javax.crypto.ShortBufferException; import javax.crypto.spec.IvParameterSpec; class PacketEncrypterImpl extends FilterOutputStream @@ -68,35 +67,19 @@ implements PacketEncrypter { @Override public void write(int b) throws IOException { - byte[] buf = new byte[] {(byte) b}; - try { - int i = packetCipher.update(buf, 0, buf.length, buf); - assert i <= 1; - if(i == 1) out.write(b); - } catch(ShortBufferException badCipher) { - throw new RuntimeException(badCipher); - } + byte[] ciphertext = packetCipher.update(new byte[] {(byte) b}); + if(ciphertext != null) out.write(ciphertext); } @Override public void write(byte[] b) throws IOException { - try { - int i = packetCipher.update(b, 0, b.length, b); - assert i <= b.length; - out.write(b, 0, i); - } catch(ShortBufferException badCipher) { - throw new RuntimeException(badCipher); - } + byte[] ciphertext = packetCipher.update(b); + if(ciphertext != null) out.write(ciphertext); } @Override public void write(byte[] b, int off, int len) throws IOException { - try { - int i = packetCipher.update(b, off, len, b, off); - assert i <= len; - out.write(b, off, i); - } catch(ShortBufferException badCipher) { - throw new RuntimeException(badCipher); - } + byte[] ciphertext = packetCipher.update(b, off, len); + if(ciphertext != null) out.write(ciphertext); } } diff --git a/components/net/sf/briar/transport/PacketReaderFactoryImpl.java b/components/net/sf/briar/transport/PacketReaderFactoryImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..119daa9cddc36fc7c8eddfd89816dd6317bcb0f1 --- /dev/null +++ b/components/net/sf/briar/transport/PacketReaderFactoryImpl.java @@ -0,0 +1,74 @@ +package net.sf.briar.transport; + +import java.io.InputStream; +import java.security.InvalidKeyException; + +import javax.crypto.Cipher; +import javax.crypto.Mac; +import javax.crypto.SecretKey; + +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.Ack; +import net.sf.briar.api.protocol.Batch; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.SubscriptionUpdate; +import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.serial.ObjectReader; +import net.sf.briar.api.serial.ReaderFactory; +import net.sf.briar.api.transport.PacketReader; +import net.sf.briar.api.transport.PacketReaderFactory; + +import com.google.inject.Inject; +import com.google.inject.Provider; + +class PacketReaderFactoryImpl implements PacketReaderFactory { + + private final CryptoComponent crypto; + private final ReaderFactory readerFactory; + private final Provider<ObjectReader<Ack>> ackProvider; + private final Provider<ObjectReader<Batch>> batchProvider; + private final Provider<ObjectReader<Offer>> offerProvider; + private final Provider<ObjectReader<Request>> requestProvider; + private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider; + private final Provider<ObjectReader<TransportUpdate>> transportProvider; + + @Inject + PacketReaderFactoryImpl(CryptoComponent crypto, ReaderFactory readerFactory, + Provider<ObjectReader<Ack>> ackProvider, + Provider<ObjectReader<Batch>> batchProvider, + Provider<ObjectReader<Offer>> offerProvider, + Provider<ObjectReader<Request>> requestProvider, + Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider, + Provider<ObjectReader<TransportUpdate>> transportProvider) { + this.crypto = crypto; + this.readerFactory = readerFactory; + this.ackProvider = ackProvider; + this.batchProvider = batchProvider; + this.offerProvider = offerProvider; + this.requestProvider = requestProvider; + this.subscriptionProvider = subscriptionProvider; + this.transportProvider = transportProvider; + } + + public PacketReader createPacketReader(byte[] firstTag, InputStream in, + int transportId, long connection, byte[] secret) { + SecretKey macKey = crypto.deriveMacKey(secret); + SecretKey tagKey = crypto.deriveTagKey(secret); + SecretKey packetKey = crypto.derivePacketKey(secret); + Cipher tagCipher = crypto.getTagCipher(); + Cipher packetCipher = crypto.getPacketCipher(); + Mac mac = crypto.getMac(); + try { + mac.init(macKey); + } catch(InvalidKeyException e) { + throw new IllegalArgumentException(e); + } + PacketDecrypter decrypter = new PacketDecrypterImpl(firstTag, in, + tagCipher, packetCipher, tagKey, packetKey); + return new PacketReaderImpl(firstTag, readerFactory, ackProvider.get(), + batchProvider.get(), offerProvider.get(), requestProvider.get(), + subscriptionProvider.get(), transportProvider.get(), + decrypter, mac, transportId, connection); + } +} diff --git a/components/net/sf/briar/transport/PacketReaderImpl.java b/components/net/sf/briar/transport/PacketReaderImpl.java index 7b654f1adc536c20c30bd64f0127d1ada9911ae8..363cc0a41a3b4b9e6dc1c976374e8d206832abfe 100644 --- a/components/net/sf/briar/transport/PacketReaderImpl.java +++ b/components/net/sf/briar/transport/PacketReaderImpl.java @@ -13,8 +13,8 @@ import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.ProtocolReaderFactory; import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.transport.PacketReader; @@ -31,18 +31,21 @@ class PacketReaderImpl implements PacketReader { private boolean betweenPackets = true; PacketReaderImpl(byte[] firstTag, ReaderFactory readerFactory, - ProtocolReaderFactory protocol, PacketDecrypter decrypter, Mac mac, - int transportId, long connection) { + ObjectReader<Ack> ackReader, ObjectReader<Batch> batchReader, + ObjectReader<Offer> offerReader, + ObjectReader<Request> requestReader, + ObjectReader<SubscriptionUpdate> subscriptionReader, + ObjectReader<TransportUpdate> transportReader, + PacketDecrypter decrypter, Mac mac, int transportId, + long connection) { InputStream in = decrypter.getInputStream(); reader = readerFactory.createReader(in); - reader.addObjectReader(Tags.ACK, protocol.createAckReader(in)); - reader.addObjectReader(Tags.BATCH, protocol.createBatchReader(in)); - reader.addObjectReader(Tags.OFFER, protocol.createOfferReader(in)); - reader.addObjectReader(Tags.REQUEST, protocol.createRequestReader(in)); - reader.addObjectReader(Tags.SUBSCRIPTIONS, - protocol.createSubscriptionReader(in)); - reader.addObjectReader(Tags.TRANSPORTS, - protocol.createTransportReader(in)); + reader.addObjectReader(Tags.ACK, ackReader); + reader.addObjectReader(Tags.BATCH, batchReader); + reader.addObjectReader(Tags.OFFER, offerReader); + reader.addObjectReader(Tags.REQUEST, requestReader); + reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader); + reader.addObjectReader(Tags.TRANSPORTS, transportReader); reader.addConsumer(new MacConsumer(mac)); this.decrypter = decrypter; this.mac = mac; @@ -51,10 +54,6 @@ class PacketReaderImpl implements PacketReader { this.connection = connection; } - public boolean eof() throws IOException { - return reader.eof(); - } - public boolean hasAck() throws IOException { if(betweenPackets) readTag(); return reader.hasUserDefined(Tags.ACK); diff --git a/components/net/sf/briar/transport/PacketWriterFactoryImpl.java b/components/net/sf/briar/transport/PacketWriterFactoryImpl.java index 2c4a48a025147cc19449f9e661c8f64a20a41d27..cf411c2e08ae46a3cddcf9e0e7979b50975654e0 100644 --- a/components/net/sf/briar/transport/PacketWriterFactoryImpl.java +++ b/components/net/sf/briar/transport/PacketWriterFactoryImpl.java @@ -3,6 +3,7 @@ package net.sf.briar.transport; import java.io.OutputStream; import java.security.InvalidKeyException; +import javax.crypto.Cipher; import javax.crypto.Mac; import javax.crypto.SecretKey; @@ -22,17 +23,20 @@ class PacketWriterFactoryImpl implements PacketWriterFactory { } public PacketWriter createPacketWriter(OutputStream out, int transportId, - long connection, SecretKey macKey, SecretKey tagKey, - SecretKey packetKey) { + long connection, byte[] secret) { + SecretKey macKey = crypto.deriveMacKey(secret); + SecretKey tagKey = crypto.deriveTagKey(secret); + SecretKey packetKey = crypto.derivePacketKey(secret); + Cipher tagCipher = crypto.getTagCipher(); + Cipher packetCipher = crypto.getPacketCipher(); Mac mac = crypto.getMac(); try { mac.init(macKey); } catch(InvalidKeyException e) { throw new IllegalArgumentException(e); } - PacketEncrypter e = new PacketEncrypterImpl(out, crypto.getTagCipher(), - crypto.getPacketCipher(), tagKey, packetKey); - return new PacketWriterImpl(e, mac, transportId, - connection); + PacketEncrypter encrypter = new PacketEncrypterImpl(out, tagCipher, + packetCipher, tagKey, packetKey); + return new PacketWriterImpl(encrypter, mac, transportId, connection); } } diff --git a/components/net/sf/briar/transport/PacketWriterImpl.java b/components/net/sf/briar/transport/PacketWriterImpl.java index 16c9ea8b79d30668006035d80e41dff97aa97165..dd78de97f656ed582a85974fbf5323879e51274f 100644 --- a/components/net/sf/briar/transport/PacketWriterImpl.java +++ b/components/net/sf/briar/transport/PacketWriterImpl.java @@ -56,10 +56,10 @@ class PacketWriterImpl extends FilterOutputStream implements PacketWriter { } @Override - public void write(byte[] b, int len, int off) throws IOException { + public void write(byte[] b, int off, int len) throws IOException { if(betweenPackets) writeTag(); - out.write(b, len, off); - mac.update(b, len, off); + out.write(b, off, len); + mac.update(b, off, len); } private void writeMac() throws IOException { diff --git a/components/net/sf/briar/transport/TransportModule.java b/components/net/sf/briar/transport/TransportModule.java index e1dd260a0c4273ff63933924b489ac60eb4f02dc..1fefd05eafab81a84782e5887fbb02fa9592cf03 100644 --- a/components/net/sf/briar/transport/TransportModule.java +++ b/components/net/sf/briar/transport/TransportModule.java @@ -1,6 +1,8 @@ package net.sf.briar.transport; import net.sf.briar.api.transport.ConnectionWindowFactory; +import net.sf.briar.api.transport.PacketReaderFactory; +import net.sf.briar.api.transport.PacketWriterFactory; import com.google.inject.AbstractModule; @@ -10,5 +12,7 @@ public class TransportModule extends AbstractModule { protected void configure() { bind(ConnectionWindowFactory.class).to( ConnectionWindowFactoryImpl.class); + bind(PacketReaderFactory.class).to(PacketReaderFactoryImpl.class); + bind(PacketWriterFactory.class).to(PacketWriterFactoryImpl.class); } } diff --git a/test/build.xml b/test/build.xml index 708e057603d5fd36c953730a1712945f11bc3f06..e3eaf8b3cdb7d1602eca0e098c7d2466281f8fbe 100644 --- a/test/build.xml +++ b/test/build.xml @@ -13,6 +13,7 @@ <path refid='test-classes'/> <path refid='util-classes'/> </classpath> + <test name='net.sf.briar.FileReadWriteTest'/> <test name='net.sf.briar.crypto.CounterModeTest'/> <test name='net.sf.briar.db.BasicH2Test'/> <test name='net.sf.briar.db.DatabaseCleanerImplTest'/> @@ -25,7 +26,6 @@ <test name='net.sf.briar.protocol.AckReaderTest'/> <test name='net.sf.briar.protocol.BatchReaderTest'/> <test name='net.sf.briar.protocol.ConsumersTest'/> - <test name='net.sf.briar.protocol.FileReadWriteTest'/> <test name='net.sf.briar.protocol.RequestReaderTest'/> <test name='net.sf.briar.protocol.SigningDigestingOutputStreamTest'/> <test name='net.sf.briar.protocol.writers.RequestWriterImplTest'/> diff --git a/test/net/sf/briar/protocol/FileReadWriteTest.java b/test/net/sf/briar/FileReadWriteTest.java similarity index 78% rename from test/net/sf/briar/protocol/FileReadWriteTest.java rename to test/net/sf/briar/FileReadWriteTest.java index 7762655319442d6b1cf3156869a028c3904c196e..f04f8375ff2631d709ff4eaf6bfc3427b6cc5fdd 100644 --- a/test/net/sf/briar/protocol/FileReadWriteTest.java +++ b/test/net/sf/briar/FileReadWriteTest.java @@ -1,19 +1,19 @@ -package net.sf.briar.protocol; +package net.sf.briar; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; +import java.io.OutputStream; import java.security.KeyPair; import java.util.Arrays; import java.util.BitSet; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.Map; import junit.framework.TestCase; -import net.sf.briar.TestUtils; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Author; @@ -28,7 +28,6 @@ import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.writers.AckWriter; @@ -38,11 +37,15 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.protocol.writers.RequestWriter; import net.sf.briar.api.protocol.writers.SubscriptionWriter; import net.sf.briar.api.protocol.writers.TransportWriter; -import net.sf.briar.api.serial.Reader; -import net.sf.briar.api.serial.ReaderFactory; +import net.sf.briar.api.transport.PacketReader; +import net.sf.briar.api.transport.PacketReaderFactory; +import net.sf.briar.api.transport.PacketWriter; +import net.sf.briar.api.transport.PacketWriterFactory; import net.sf.briar.crypto.CryptoModule; +import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.writers.WritersModule; import net.sf.briar.serial.SerialModule; +import net.sf.briar.transport.TransportModule; import org.junit.After; import org.junit.Before; @@ -59,15 +62,13 @@ public class FileReadWriteTest extends TestCase { private final BatchId ack = new BatchId(TestUtils.getRandomId()); private final long start = System.currentTimeMillis(); - private final ReaderFactory readerFactory; + private final PacketReaderFactory packetReaderFactory; + private final PacketWriterFactory packetWriterFactory; private final ProtocolWriterFactory protocolWriterFactory; private final CryptoComponent crypto; - private final AckReader ackReader; - private final BatchReader batchReader; - private final OfferReader offerReader; - private final RequestReader requestReader; - private final SubscriptionReader subscriptionReader; - private final TransportReader transportReader; + private final byte[] secret = new byte[45]; + private final int transportId = 123; + private final long connection = 234L; private final Author author; private final Group group, group1; private final Message message, message1, message2, message3; @@ -78,19 +79,14 @@ public class FileReadWriteTest extends TestCase { public FileReadWriteTest() throws Exception { super(); Injector i = Guice.createInjector(new CryptoModule(), - new ProtocolModule(), new SerialModule(), + new ProtocolModule(), new SerialModule(), new TransportModule(), new WritersModule()); - readerFactory = i.getInstance(ReaderFactory.class); + packetReaderFactory = i.getInstance(PacketReaderFactory.class); + packetWriterFactory = i.getInstance(PacketWriterFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); crypto = i.getInstance(CryptoComponent.class); assertEquals(crypto.getMessageDigest().getDigestLength(), UniqueId.LENGTH); - ackReader = i.getInstance(AckReader.class); - batchReader = i.getInstance(BatchReader.class); - offerReader = i.getInstance(OfferReader.class); - requestReader = i.getInstance(RequestReader.class); - subscriptionReader = i.getInstance(SubscriptionReader.class); - transportReader = i.getInstance(TransportReader.class); // Create two groups: one restricted, one unrestricted GroupFactory groupFactory = i.getInstance(GroupFactory.class); group = groupFactory.createGroup("Unrestricted group", null); @@ -124,11 +120,15 @@ public class FileReadWriteTest extends TestCase { @Test public void testWriteFile() throws Exception { - FileOutputStream out = new FileOutputStream(file); + OutputStream out = new FileOutputStream(file); + PacketWriter p = packetWriterFactory.createPacketWriter(out, + transportId, connection, secret); + out = p.getOutputStream(); AckWriter a = protocolWriterFactory.createAckWriter(out); assertTrue(a.writeBatchId(ack)); a.finish(); + p.nextPacket(); BatchWriter b = protocolWriterFactory.createBatchWriter(out); assertTrue(b.writeMessage(message.getBytes())); @@ -136,6 +136,7 @@ public class FileReadWriteTest extends TestCase { assertTrue(b.writeMessage(message2.getBytes())); assertTrue(b.writeMessage(message3.getBytes())); b.finish(); + p.nextPacket(); OfferWriter o = protocolWriterFactory.createOfferWriter(out); assertTrue(o.writeMessageId(message.getId())); @@ -143,12 +144,14 @@ public class FileReadWriteTest extends TestCase { assertTrue(o.writeMessageId(message2.getId())); assertTrue(o.writeMessageId(message3.getId())); o.finish(); + p.nextPacket(); RequestWriter r = protocolWriterFactory.createRequestWriter(out); BitSet requested = new BitSet(4); requested.set(1); requested.set(3); r.writeBitmap(requested, 4); + p.nextPacket(); SubscriptionWriter s = protocolWriterFactory.createSubscriptionWriter(out); @@ -157,10 +160,13 @@ public class FileReadWriteTest extends TestCase { subs.put(group, 0L); subs.put(group1, 0L); s.writeSubscriptions(subs); + p.nextPacket(); TransportWriter t = protocolWriterFactory.createTransportWriter(out); t.writeTransports(transports); + p.nextPacket(); + out.flush(); out.close(); assertTrue(file.exists()); assertTrue(file.length() > message.getSize()); @@ -172,22 +178,25 @@ public class FileReadWriteTest extends TestCase { testWriteFile(); FileInputStream in = new FileInputStream(file); - Reader reader = readerFactory.createReader(in); - reader.addObjectReader(Tags.ACK, ackReader); - reader.addObjectReader(Tags.BATCH, batchReader); - reader.addObjectReader(Tags.OFFER, offerReader); - reader.addObjectReader(Tags.REQUEST, requestReader); - reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader); - reader.addObjectReader(Tags.TRANSPORTS, transportReader); + byte[] firstTag = new byte[16]; + int offset = 0; + while(offset < 16) { + int read = in.read(firstTag, offset, firstTag.length - offset); + if(read == -1) break; + offset += read; + } + assertEquals(16, offset); + PacketReader p = packetReaderFactory.createPacketReader(firstTag, in, + transportId, connection, secret); // Read the ack - assertTrue(reader.hasUserDefined(Tags.ACK)); - Ack a = reader.readUserDefined(Tags.ACK, Ack.class); + assertTrue(p.hasAck()); + Ack a = p.readAck(); assertEquals(Collections.singletonList(ack), a.getBatchIds()); // Read the batch - assertTrue(reader.hasUserDefined(Tags.BATCH)); - Batch b = reader.readUserDefined(Tags.BATCH, Batch.class); + assertTrue(p.hasBatch()); + Batch b = p.readBatch(); Collection<Message> messages = b.getMessages(); assertEquals(4, messages.size()); Iterator<Message> it = messages.iterator(); @@ -197,8 +206,8 @@ public class FileReadWriteTest extends TestCase { checkMessageEquality(message3, it.next()); // Read the offer - assertTrue(reader.hasUserDefined(Tags.OFFER)); - Offer o = reader.readUserDefined(Tags.OFFER, Offer.class); + assertTrue(p.hasOffer()); + Offer o = p.readOffer(); Collection<MessageId> offered = o.getMessageIds(); assertEquals(4, offered.size()); Iterator<MessageId> it1 = offered.iterator(); @@ -208,8 +217,8 @@ public class FileReadWriteTest extends TestCase { assertEquals(message3.getId(), it1.next()); // Read the request - assertTrue(reader.hasUserDefined(Tags.REQUEST)); - Request r = reader.readUserDefined(Tags.REQUEST, Request.class); + assertTrue(p.hasRequest()); + Request r = p.readRequest(); BitSet requested = r.getBitmap(); assertFalse(requested.get(0)); assertTrue(requested.get(1)); @@ -219,9 +228,8 @@ public class FileReadWriteTest extends TestCase { assertEquals(2, requested.cardinality()); // Read the subscription update - assertTrue(reader.hasUserDefined(Tags.SUBSCRIPTIONS)); - SubscriptionUpdate s = reader.readUserDefined(Tags.SUBSCRIPTIONS, - SubscriptionUpdate.class); + assertTrue(p.hasSubscriptionUpdate()); + SubscriptionUpdate s = p.readSubscriptionUpdate(); Map<Group, Long> subs = s.getSubscriptions(); assertEquals(2, subs.size()); assertEquals(Long.valueOf(0L), subs.get(group)); @@ -230,14 +238,11 @@ public class FileReadWriteTest extends TestCase { assertTrue(s.getTimestamp() <= System.currentTimeMillis()); // Read the transport update - assertTrue(reader.hasUserDefined(Tags.TRANSPORTS)); - TransportUpdate t = reader.readUserDefined(Tags.TRANSPORTS, - TransportUpdate.class); + assertTrue(p.hasTransportUpdate()); + TransportUpdate t = p.readTransportUpdate(); assertEquals(transports, t.getTransports()); assertTrue(t.getTimestamp() > start); assertTrue(t.getTimestamp() <= System.currentTimeMillis()); - - assertTrue(reader.eof()); } @After