diff --git a/api/net/sf/briar/api/protocol/Ack.java b/api/net/sf/briar/api/protocol/Ack.java index 40ebe1865d8d9a605f80685901ad8f3c23ccf63a..c8822c2d5dad9f573e3b6c33ffcb8b86a02bf207 100644 --- a/api/net/sf/briar/api/protocol/Ack.java +++ b/api/net/sf/briar/api/protocol/Ack.java @@ -12,5 +12,5 @@ public interface Ack { static final int MAX_SIZE = (1024 * 1024) - 100; /** Returns the IDs of the acknowledged batches. */ - Collection<BatchId> getBatches(); + Collection<BatchId> getBatchIds(); } diff --git a/api/net/sf/briar/api/protocol/Offer.java b/api/net/sf/briar/api/protocol/Offer.java index e65d07394b222f77d25ab5b41bfcc132c1c2e4e6..e62f05d56aac13afa2bcf3863f6340f665d8512f 100644 --- a/api/net/sf/briar/api/protocol/Offer.java +++ b/api/net/sf/briar/api/protocol/Offer.java @@ -12,5 +12,5 @@ public interface Offer { static final int MAX_SIZE = (1024 * 1024) - 100; /** Returns the message IDs contained in the offer. */ - Collection<MessageId> getMessages(); + Collection<MessageId> getMessageIds(); } diff --git a/api/net/sf/briar/api/protocol/Request.java b/api/net/sf/briar/api/protocol/Request.java new file mode 100644 index 0000000000000000000000000000000000000000..ea3f461f3f01180739904fd438c13edf1feb4a0f --- /dev/null +++ b/api/net/sf/briar/api/protocol/Request.java @@ -0,0 +1,19 @@ +package net.sf.briar.api.protocol; + +import java.util.BitSet; + +/** A packet requesting some or all of the messages from an offer. */ +public interface Request { + + /** + * The maximum size of a serialised request, exlcuding encryption and + * authentication. + */ + static final int MAX_SIZE = (1024 * 1024) - 100; + + /** + * Returns a sequence of bits corresponding to the sequence of messages in + * the offer, where the i^th bit is set if the i^th message should be sent. + */ + BitSet getBitmap(); +} diff --git a/components/net/sf/briar/db/ReadWriteLockDatabaseComponent.java b/components/net/sf/briar/db/ReadWriteLockDatabaseComponent.java index de0139a6f24451961f045e5a63c5a9340dab376d..53537c863b0d4ac27521c01bab9d22a3f55b74d6 100644 --- a/components/net/sf/briar/db/ReadWriteLockDatabaseComponent.java +++ b/components/net/sf/briar/db/ReadWriteLockDatabaseComponent.java @@ -597,7 +597,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> { try { messageStatusLock.writeLock().lock(); try { - Collection<BatchId> acks = a.getBatches(); + Collection<BatchId> acks = a.getBatchIds(); for(BatchId ack : acks) { Txn txn = db.startTransaction(); try { @@ -676,7 +676,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> { try { subscriptionLock.readLock().lock(); try { - Collection<MessageId> offered = o.getMessages(); + Collection<MessageId> offered = o.getMessageIds(); BitSet request = new BitSet(offered.size()); Txn txn = db.startTransaction(); try { diff --git a/components/net/sf/briar/db/SynchronizedDatabaseComponent.java b/components/net/sf/briar/db/SynchronizedDatabaseComponent.java index 079ef8254c6c4a6d810b365f4bc8defe5d7effb9..0c3c796af2713652cf1b9c0f67d1ba2eff7417f2 100644 --- a/components/net/sf/briar/db/SynchronizedDatabaseComponent.java +++ b/components/net/sf/briar/db/SynchronizedDatabaseComponent.java @@ -440,7 +440,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> { if(!containsContact(c)) throw new NoSuchContactException(); synchronized(messageLock) { synchronized(messageStatusLock) { - Collection<BatchId> acks = a.getBatches(); + Collection<BatchId> acks = a.getBatchIds(); for(BatchId ack : acks) { Txn txn = db.startTransaction(); try { @@ -497,7 +497,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> { synchronized(messageLock) { synchronized(messageStatusLock) { synchronized(subscriptionLock) { - Collection<MessageId> offered = o.getMessages(); + Collection<MessageId> offered = o.getMessageIds(); BitSet request = new BitSet(offered.size()); Txn txn = db.startTransaction(); try { diff --git a/components/net/sf/briar/protocol/AckFactory.java b/components/net/sf/briar/protocol/AckFactory.java index 8ed23f90b106c9f8a78a9f703e70d47f2d48ab51..9c574db5914b6cc988e1ff8909f17efef8c87f31 100644 --- a/components/net/sf/briar/protocol/AckFactory.java +++ b/components/net/sf/briar/protocol/AckFactory.java @@ -7,5 +7,5 @@ import net.sf.briar.api.protocol.BatchId; interface AckFactory { - Ack createAck(Collection<BatchId> batches); + Ack createAck(Collection<BatchId> acked); } diff --git a/components/net/sf/briar/protocol/AckFactoryImpl.java b/components/net/sf/briar/protocol/AckFactoryImpl.java index 900022e62173d47d5454fd27f126f1b35c7dbf5c..f08715c8f50c734784aa74e4dbe1815f7ddc4067 100644 --- a/components/net/sf/briar/protocol/AckFactoryImpl.java +++ b/components/net/sf/briar/protocol/AckFactoryImpl.java @@ -7,7 +7,7 @@ import net.sf.briar.api.protocol.BatchId; class AckFactoryImpl implements AckFactory { - public Ack createAck(Collection<BatchId> batches) { - return new AckImpl(batches); + public Ack createAck(Collection<BatchId> acked) { + return new AckImpl(acked); } } diff --git a/components/net/sf/briar/protocol/AckImpl.java b/components/net/sf/briar/protocol/AckImpl.java index 06e34b25b889f395175b621805b7cc78835b8d4e..9b956ff462d155053438ba39ec34596cceb654d3 100644 --- a/components/net/sf/briar/protocol/AckImpl.java +++ b/components/net/sf/briar/protocol/AckImpl.java @@ -7,13 +7,13 @@ import net.sf.briar.api.protocol.BatchId; class AckImpl implements Ack { - private final Collection<BatchId> batches; + private final Collection<BatchId> acked; - AckImpl(Collection<BatchId> batches) { - this.batches = batches; + AckImpl(Collection<BatchId> acked) { + this.acked = acked; } - public Collection<BatchId> getBatches() { - return batches; + public Collection<BatchId> getBatchIds() { + return acked; } } diff --git a/components/net/sf/briar/protocol/OfferFactory.java b/components/net/sf/briar/protocol/OfferFactory.java index 23fbea52d438578072dcaa7142c0d9e4fb9c15b1..6f19d4e29192edd0eadd298b67d398f050391337 100644 --- a/components/net/sf/briar/protocol/OfferFactory.java +++ b/components/net/sf/briar/protocol/OfferFactory.java @@ -7,5 +7,5 @@ import net.sf.briar.api.protocol.Offer; interface OfferFactory { - Offer createOffer(Collection<MessageId> messages); + Offer createOffer(Collection<MessageId> offered); } diff --git a/components/net/sf/briar/protocol/OfferFactoryImpl.java b/components/net/sf/briar/protocol/OfferFactoryImpl.java index 8635b0b96adf24738f3b3f2080c3616ade4e8d62..075527d14768faa86d022cfe308f8f7eda203b4b 100644 --- a/components/net/sf/briar/protocol/OfferFactoryImpl.java +++ b/components/net/sf/briar/protocol/OfferFactoryImpl.java @@ -7,7 +7,7 @@ import net.sf.briar.api.protocol.Offer; class OfferFactoryImpl implements OfferFactory { - public Offer createOffer(Collection<MessageId> messages) { - return new OfferImpl(messages); + public Offer createOffer(Collection<MessageId> offered) { + return new OfferImpl(offered); } } diff --git a/components/net/sf/briar/protocol/OfferImpl.java b/components/net/sf/briar/protocol/OfferImpl.java index 129d2054e5e68eb78edbc1ac4bf3cf3e1102ff2f..de892202e64c2fc1fe435767993734920aeda5cf 100644 --- a/components/net/sf/briar/protocol/OfferImpl.java +++ b/components/net/sf/briar/protocol/OfferImpl.java @@ -7,13 +7,13 @@ import net.sf.briar.api.protocol.Offer; class OfferImpl implements Offer { - private final Collection<MessageId> messages; + private final Collection<MessageId> offered; - OfferImpl(Collection<MessageId> messages) { - this.messages = messages; + OfferImpl(Collection<MessageId> offered) { + this.offered = offered; } - public Collection<MessageId> getMessages() { - return messages; + public Collection<MessageId> getMessageIds() { + return offered; } } diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index ffd8e51c9806025ccf6ce8bffb9abfe8813916ff..57e6e2c6953721b64698e6c124eb0116cee76bd2 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -23,19 +23,21 @@ public class ProtocolModule extends AbstractModule { bind(BatchFactory.class).to(BatchFactoryImpl.class); bind(GroupFactory.class).to(GroupFactoryImpl.class); bind(OfferFactory.class).to(OfferFactoryImpl.class); + bind(RequestFactory.class).to(RequestFactoryImpl.class); bind(SubscriptionFactory.class).to(SubscriptionFactoryImpl.class); bind(TransportFactory.class).to(TransportFactoryImpl.class); bind(MessageEncoder.class).to(MessageEncoderImpl.class); } @Provides - ObjectReader<BatchId> getBatchIdReader() { - return new BatchIdReader(); + ObjectReader<Author> getAuthorReader(CryptoComponent crypto, + AuthorFactory authorFactory) { + return new AuthorReader(crypto, authorFactory); } @Provides - ObjectReader<MessageId> getMessageIdReader() { - return new MessageIdReader(); + ObjectReader<BatchId> getBatchIdReader() { + return new BatchIdReader(); } @Provides @@ -45,9 +47,8 @@ public class ProtocolModule extends AbstractModule { } @Provides - ObjectReader<Author> getAuthorReader(CryptoComponent crypto, - AuthorFactory authorFactory) { - return new AuthorReader(crypto, authorFactory); + ObjectReader<MessageId> getMessageIdReader() { + return new MessageIdReader(); } @Provides diff --git a/components/net/sf/briar/protocol/RequestFactory.java b/components/net/sf/briar/protocol/RequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..005982b826a5c32ae72efceb1fe563d24d6406ec --- /dev/null +++ b/components/net/sf/briar/protocol/RequestFactory.java @@ -0,0 +1,10 @@ +package net.sf.briar.protocol; + +import java.util.BitSet; + +import net.sf.briar.api.protocol.Request; + +interface RequestFactory { + + Request createRequest(BitSet requested); +} diff --git a/components/net/sf/briar/protocol/RequestFactoryImpl.java b/components/net/sf/briar/protocol/RequestFactoryImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..0c2c77cb1cd324c4f292adeba8da1125dfbd21f2 --- /dev/null +++ b/components/net/sf/briar/protocol/RequestFactoryImpl.java @@ -0,0 +1,12 @@ +package net.sf.briar.protocol; + +import java.util.BitSet; + +import net.sf.briar.api.protocol.Request; + +class RequestFactoryImpl implements RequestFactory { + + public Request createRequest(BitSet requested) { + return new RequestImpl(requested); + } +} diff --git a/components/net/sf/briar/protocol/RequestImpl.java b/components/net/sf/briar/protocol/RequestImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..ddb31898aeb2f4f02ae581ce889eaebf22894d09 --- /dev/null +++ b/components/net/sf/briar/protocol/RequestImpl.java @@ -0,0 +1,18 @@ +package net.sf.briar.protocol; + +import java.util.BitSet; + +import net.sf.briar.api.protocol.Request; + +class RequestImpl implements Request { + + private final BitSet requested; + + RequestImpl(BitSet requested) { + this.requested = requested; + } + + public BitSet getBitmap() { + return requested; + } +} diff --git a/components/net/sf/briar/protocol/RequestReader.java b/components/net/sf/briar/protocol/RequestReader.java new file mode 100644 index 0000000000000000000000000000000000000000..0fe33531fc89b26ae8302dbd715ccae1060b7426 --- /dev/null +++ b/components/net/sf/briar/protocol/RequestReader.java @@ -0,0 +1,41 @@ +package net.sf.briar.protocol; + +import java.io.IOException; +import java.util.BitSet; + +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.Tags; +import net.sf.briar.api.serial.Consumer; +import net.sf.briar.api.serial.ObjectReader; +import net.sf.briar.api.serial.Reader; + +import com.google.inject.Inject; + +class RequestReader implements ObjectReader<Request> { + + private final RequestFactory requestFactory; + + @Inject + RequestReader(RequestFactory requestFactory) { + this.requestFactory = requestFactory; + } + + public Request readObject(Reader r) throws IOException { + // Initialise the consumer + Consumer counting = new CountingConsumer(Request.MAX_SIZE); + // Read the data + r.addConsumer(counting); + r.readUserDefinedTag(Tags.REQUEST); + byte[] bitmap = r.readBytes(); + r.removeConsumer(counting); + // Convert the bitmap into a BitSet + BitSet b = new BitSet(bitmap.length * 8); + for(int i = 0; i < bitmap.length; i++) { + for(int j = 0; j < 8; j++) { + byte bit = (byte) (128 >> j); + if((bitmap[i] & bit) != 0) b.set(i * 8 + j); + } + } + return requestFactory.createRequest(b); + } +} diff --git a/test/net/sf/briar/db/DatabaseComponentTest.java b/test/net/sf/briar/db/DatabaseComponentTest.java index 6f843d19116ee71369242eb609e21a987c1b706e..87f967537c4a3049ddfc91a3bf019a30b5e27d6f 100644 --- a/test/net/sf/briar/db/DatabaseComponentTest.java +++ b/test/net/sf/briar/db/DatabaseComponentTest.java @@ -752,7 +752,7 @@ public abstract class DatabaseComponentTest extends TestCase { allowing(database).containsContact(txn, contactId); will(returnValue(true)); // Get the acked batches - oneOf(ack).getBatches(); + oneOf(ack).getBatchIds(); will(returnValue(Collections.singletonList(batchId))); oneOf(database).removeAckedBatch(txn, contactId, batchId); }}); @@ -940,7 +940,7 @@ public abstract class DatabaseComponentTest extends TestCase { allowing(database).containsContact(txn, contactId); will(returnValue(true)); // Get the offered messages - oneOf(offer).getMessages(); + oneOf(offer).getMessageIds(); will(returnValue(offered)); oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId); will(returnValue(false)); // Not visible - request message # 0 diff --git a/test/net/sf/briar/protocol/AckReaderTest.java b/test/net/sf/briar/protocol/AckReaderTest.java index 815553e28bdbf1fd8fe629cf7cd4a54d52876fe3..1bfbdb824354e4c0dd7c6db2924a7f9ae32f18bc 100644 --- a/test/net/sf/briar/protocol/AckReaderTest.java +++ b/test/net/sf/briar/protocol/AckReaderTest.java @@ -97,7 +97,7 @@ public class AckReaderTest extends TestCase { } private byte[] createAck(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(Ack.MAX_SIZE); + ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); w.writeUserDefinedTag(Tags.ACK); w.writeListStart(); diff --git a/test/net/sf/briar/protocol/FileReadWriteTest.java b/test/net/sf/briar/protocol/FileReadWriteTest.java index 9b77922d90458ee1c02918f204f7d43e73ef4a08..57e29a55d6b4d38a5d08e51c9239d7c389072b21 100644 --- a/test/net/sf/briar/protocol/FileReadWriteTest.java +++ b/test/net/sf/briar/protocol/FileReadWriteTest.java @@ -6,6 +6,7 @@ import java.io.FileOutputStream; import java.security.KeyPair; import java.util.ArrayList; import java.util.Arrays; +import java.util.BitSet; import java.util.Collection; import java.util.Collections; import java.util.Iterator; @@ -24,6 +25,7 @@ import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Subscriptions; import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.Transports; @@ -32,6 +34,7 @@ import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.BatchWriter; import net.sf.briar.api.protocol.writers.OfferWriter; import net.sf.briar.api.protocol.writers.PacketWriterFactory; +import net.sf.briar.api.protocol.writers.RequestWriter; import net.sf.briar.api.protocol.writers.SubscriptionWriter; import net.sf.briar.api.protocol.writers.TransportWriter; import net.sf.briar.api.serial.Reader; @@ -61,6 +64,7 @@ public class FileReadWriteTest extends TestCase { private final AckReader ackReader; private final BatchReader batchReader; private final OfferReader offerReader; + private final RequestReader requestReader; private final SubscriptionReader subscriptionReader; private final TransportReader transportReader; private final Author author; @@ -82,6 +86,7 @@ public class FileReadWriteTest extends TestCase { ackReader = i.getInstance(AckReader.class); batchReader = i.getInstance(BatchReader.class); offerReader = i.getInstance(OfferReader.class); + requestReader = i.getInstance(RequestReader.class); subscriptionReader = i.getInstance(SubscriptionReader.class); transportReader = i.getInstance(TransportReader.class); // Create two groups: one restricted, one unrestricted @@ -135,6 +140,12 @@ public class FileReadWriteTest extends TestCase { assertTrue(o.writeMessageId(message3.getId())); o.finish(); + RequestWriter r = packetWriterFactory.createRequestWriter(out); + BitSet requested = new BitSet(4); + requested.set(1); + requested.set(3); + r.writeBitmap(requested, 4); + SubscriptionWriter s = packetWriterFactory.createSubscriptionWriter(out); Collection<Group> subs = new ArrayList<Group>(); @@ -160,35 +171,47 @@ public class FileReadWriteTest extends TestCase { reader.addObjectReader(Tags.ACK, ackReader); reader.addObjectReader(Tags.BATCH, batchReader); reader.addObjectReader(Tags.OFFER, offerReader); + reader.addObjectReader(Tags.REQUEST, requestReader); reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader); reader.addObjectReader(Tags.TRANSPORTS, transportReader); // Read the ack assertTrue(reader.hasUserDefined(Tags.ACK)); Ack a = reader.readUserDefined(Tags.ACK, Ack.class); - assertEquals(Collections.singletonList(ack), a.getBatches()); + assertEquals(Collections.singletonList(ack), a.getBatchIds()); // Read the batch assertTrue(reader.hasUserDefined(Tags.BATCH)); Batch b = reader.readUserDefined(Tags.BATCH, Batch.class); Collection<Message> messages = b.getMessages(); assertEquals(4, messages.size()); - Iterator<Message> i = messages.iterator(); - checkMessageEquality(message, i.next()); - checkMessageEquality(message1, i.next()); - checkMessageEquality(message2, i.next()); - checkMessageEquality(message3, i.next()); + Iterator<Message> it = messages.iterator(); + checkMessageEquality(message, it.next()); + checkMessageEquality(message1, it.next()); + checkMessageEquality(message2, it.next()); + checkMessageEquality(message3, it.next()); // Read the offer assertTrue(reader.hasUserDefined(Tags.OFFER)); Offer o = reader.readUserDefined(Tags.OFFER, Offer.class); - Collection<MessageId> ids = o.getMessages(); - assertEquals(4, ids.size()); - Iterator<MessageId> i1 = ids.iterator(); - assertEquals(message.getId(), i1.next()); - assertEquals(message1.getId(), i1.next()); - assertEquals(message2.getId(), i1.next()); - assertEquals(message3.getId(), i1.next()); + Collection<MessageId> offered = o.getMessageIds(); + assertEquals(4, offered.size()); + Iterator<MessageId> it1 = offered.iterator(); + assertEquals(message.getId(), it1.next()); + assertEquals(message1.getId(), it1.next()); + assertEquals(message2.getId(), it1.next()); + assertEquals(message3.getId(), it1.next()); + + // Read the request + assertTrue(reader.hasUserDefined(Tags.REQUEST)); + Request r = reader.readUserDefined(Tags.REQUEST, Request.class); + BitSet requested = r.getBitmap(); + assertFalse(requested.get(0)); + assertTrue(requested.get(1)); + assertFalse(requested.get(2)); + assertTrue(requested.get(3)); + // If there are any padding bits, they should all be zero + for(int i = 4; i < requested.size(); i++) assertFalse(requested.get(i)); // Read the subscriptions update assertTrue(reader.hasUserDefined(Tags.SUBSCRIPTIONS)); @@ -196,9 +219,9 @@ public class FileReadWriteTest extends TestCase { Subscriptions.class); Collection<Group> subs = s.getSubscriptions(); assertEquals(2, subs.size()); - Iterator<Group> i2 = subs.iterator(); - checkGroupEquality(group, i2.next()); - checkGroupEquality(group1, i2.next()); + Iterator<Group> it2 = subs.iterator(); + checkGroupEquality(group, it2.next()); + checkGroupEquality(group1, it2.next()); assertTrue(s.getTimestamp() > start); assertTrue(s.getTimestamp() <= System.currentTimeMillis()); diff --git a/test/net/sf/briar/protocol/RequestReaderTest.java b/test/net/sf/briar/protocol/RequestReaderTest.java new file mode 100644 index 0000000000000000000000000000000000000000..e17022098f9043a06e7558c8e28960857af4cf32 --- /dev/null +++ b/test/net/sf/briar/protocol/RequestReaderTest.java @@ -0,0 +1,133 @@ +package net.sf.briar.protocol; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.BitSet; + +import junit.framework.TestCase; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.Tags; +import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.Reader; +import net.sf.briar.api.serial.ReaderFactory; +import net.sf.briar.api.serial.Writer; +import net.sf.briar.api.serial.WriterFactory; +import net.sf.briar.serial.SerialModule; + +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class RequestReaderTest extends TestCase { + + private final ReaderFactory readerFactory; + private final WriterFactory writerFactory; + private final Mockery context; + + public RequestReaderTest() throws Exception { + super(); + Injector i = Guice.createInjector(new SerialModule()); + readerFactory = i.getInstance(ReaderFactory.class); + writerFactory = i.getInstance(WriterFactory.class); + context = new Mockery(); + } + + @Test + public void testFormatExceptionIfRequestIsTooLarge() throws Exception { + RequestFactory requestFactory = context.mock(RequestFactory.class); + RequestReader requestReader = new RequestReader(requestFactory); + + byte[] b = createRequest(true); + ByteArrayInputStream in = new ByteArrayInputStream(b); + Reader reader = readerFactory.createReader(in); + reader.addObjectReader(Tags.REQUEST, requestReader); + + try { + reader.readUserDefined(Tags.REQUEST, Request.class); + assertTrue(false); + } catch(FormatException expected) {} + context.assertIsSatisfied(); + } + + @Test + public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception { + final RequestFactory requestFactory = + context.mock(RequestFactory.class); + RequestReader requestReader = new RequestReader(requestFactory); + final Request request = context.mock(Request.class); + context.checking(new Expectations() {{ + oneOf(requestFactory).createRequest(with(any(BitSet.class))); + will(returnValue(request)); + }}); + + byte[] b = createRequest(false); + ByteArrayInputStream in = new ByteArrayInputStream(b); + Reader reader = readerFactory.createReader(in); + reader.addObjectReader(Tags.REQUEST, requestReader); + + assertEquals(request, reader.readUserDefined(Tags.REQUEST, + Request.class)); + context.assertIsSatisfied(); + } + + @Test + public void testBitmapDecoding() throws Exception { + // Test sizes from 0 to 1000 bits + for(int i = 0; i < 1000; i++) { + // Create a BitSet of size i with one in ten bits set (on average) + BitSet requested = new BitSet(i); + for(int j = 0; j < i; j++) if(Math.random() < 0.1) requested.set(j); + // Encode the BitSet as a bitmap + int bytes = i % 8 == 0 ? i / 8 : i / 8 + 1; + byte[] bitmap = new byte[bytes]; + for(int j = 0; j < i; j++) { + if(requested.get(j)) { + int offset = j / 8; + byte bit = (byte) (128 >> j % 8); + bitmap[offset] |= bit; + } + } + // Create a serialised request containing the bitmap + byte[] b = createRequest(bitmap); + // Deserialise the request + ByteArrayInputStream in = new ByteArrayInputStream(b); + Reader reader = readerFactory.createReader(in); + RequestReader requestReader = + new RequestReader(new RequestFactoryImpl()); + reader.addObjectReader(Tags.REQUEST, requestReader); + Request r = reader.readUserDefined(Tags.REQUEST, Request.class); + BitSet decoded = r.getBitmap(); + // Check that the decoded BitSet matches the original - we can't + // use equals() because of padding, but the first i bits should + // match and the cardinalities should be equal, indicating that no + // padding bits are set + for(int j = 0; j < i; j++) { + assertEquals(requested.get(j), decoded.get(j)); + } + assertEquals(requested.cardinality(), decoded.cardinality()); + } + } + + private byte[] createRequest(boolean tooBig) throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Writer w = writerFactory.createWriter(out); + w.writeUserDefinedTag(Tags.REQUEST); + // Allow one byte for the REQUEST tag, one byte for the BYTES tag, and + // five bytes for the length as an int32 + if(tooBig) w.writeBytes(new byte[Request.MAX_SIZE - 6]); + else w.writeBytes(new byte[Request.MAX_SIZE - 7]); + assertEquals(tooBig, out.size() > Request.MAX_SIZE); + return out.toByteArray(); + } + + private byte[] createRequest(byte[] bitmap) throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Writer w = writerFactory.createWriter(out); + w.writeUserDefinedTag(Tags.REQUEST); + w.writeBytes(bitmap); + return out.toByteArray(); + } +}