diff --git a/api/net/sf/briar/api/serial/ObjectReader.java b/api/net/sf/briar/api/serial/ObjectReader.java index 2f345e9ce26c24e8846fe36194d986e3969d43c0..012f5f800aae92025570e16520fc8b38e81d5a16 100644 --- a/api/net/sf/briar/api/serial/ObjectReader.java +++ b/api/net/sf/briar/api/serial/ObjectReader.java @@ -1,9 +1,8 @@ package net.sf.briar.api.serial; import java.io.IOException; -import java.security.GeneralSecurityException; public interface ObjectReader<T> { - T readObject(Reader r) throws IOException, GeneralSecurityException; + T readObject(Reader r) throws IOException; } diff --git a/api/net/sf/briar/api/serial/Reader.java b/api/net/sf/briar/api/serial/Reader.java index b1b7997d95ba86d7003ce8c8533f7d6536cedc2a..9b80ebc4361964a502f7a07c9314f8c0228aded2 100644 --- a/api/net/sf/briar/api/serial/Reader.java +++ b/api/net/sf/briar/api/serial/Reader.java @@ -1,7 +1,6 @@ package net.sf.briar.api.serial; import java.io.IOException; -import java.security.GeneralSecurityException; import java.util.List; import java.util.Map; @@ -43,18 +42,16 @@ public interface Reader { byte[] readRaw() throws IOException; boolean hasList() throws IOException; - List<Object> readList() throws IOException, GeneralSecurityException; - <E> List<E> readList(Class<E> e) throws IOException, - GeneralSecurityException; + List<Object> readList() throws IOException; + <E> List<E> readList(Class<E> e) throws IOException; boolean hasListStart() throws IOException; void readListStart() throws IOException; boolean hasListEnd() throws IOException; void readListEnd() throws IOException; boolean hasMap() throws IOException; - Map<Object, Object> readMap() throws IOException, GeneralSecurityException; - <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException, - GeneralSecurityException; + Map<Object, Object> readMap() throws IOException; + <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException; boolean hasMapStart() throws IOException; void readMapStart() throws IOException; boolean hasMapEnd() throws IOException; @@ -63,9 +60,7 @@ public interface Reader { boolean hasNull() throws IOException; void readNull() throws IOException; - boolean hasUserDefinedTag() throws IOException; - int readUserDefinedTag() throws IOException; + boolean hasUserDefined(int tag) throws IOException; + <T> T readUserDefined(int tag, Class<T> t) throws IOException; void readUserDefinedTag(int tag) throws IOException; - <T> T readUserDefinedObject(int tag, Class<T> t) throws IOException, - GeneralSecurityException; } diff --git a/components/net/sf/briar/protocol/BatchIdReader.java b/components/net/sf/briar/protocol/BatchIdReader.java index 5114141f19ec9d7d0adefebc6a1f47f5f2e9d813..c084dede86c05d3bdac5b1f04346da1f4cefb288 100644 --- a/components/net/sf/briar/protocol/BatchIdReader.java +++ b/components/net/sf/briar/protocol/BatchIdReader.java @@ -3,6 +3,7 @@ package net.sf.briar.protocol; import java.io.IOException; import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.ObjectReader; @@ -11,6 +12,7 @@ import net.sf.briar.api.serial.Reader; public class BatchIdReader implements ObjectReader<BatchId> { public BatchId readObject(Reader r) throws IOException { + r.readUserDefinedTag(Tags.BATCH_ID); byte[] b = r.readRaw(); if(b.length != UniqueId.LENGTH) throw new FormatException(); return new BatchId(b); diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java index 5cc9496ea75ef7bc5b983a8eccc4ef87720cfa67..a9611649f2c28d67322a35f80595e3ea4f122478 100644 --- a/components/net/sf/briar/protocol/BatchReader.java +++ b/components/net/sf/briar/protocol/BatchReader.java @@ -1,7 +1,6 @@ package net.sf.briar.protocol; import java.io.IOException; -import java.security.GeneralSecurityException; import java.security.MessageDigest; import java.util.List; @@ -25,21 +24,20 @@ public class BatchReader implements ObjectReader<Batch> { this.batchFactory = batchFactory; } - public Batch readObject(Reader reader) throws IOException, - GeneralSecurityException { - // Initialise the consumers - the initial tag has already been read, so - // subtract one from the maximum size - CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE - 1); + public Batch readObject(Reader r) throws IOException { + // Initialise the consumers + CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE); DigestingConsumer digesting = new DigestingConsumer(messageDigest); messageDigest.reset(); // Read and digest the data - reader.addConsumer(counting); - reader.addConsumer(digesting); - reader.addObjectReader(Tags.MESSAGE, messageReader); - List<Message> messages = reader.readList(Message.class); - reader.removeObjectReader(Tags.MESSAGE); - reader.removeConsumer(digesting); - reader.removeConsumer(counting); + r.addConsumer(counting); + r.addConsumer(digesting); + r.readUserDefinedTag(Tags.BATCH); + r.addObjectReader(Tags.MESSAGE, messageReader); + List<Message> messages = r.readList(Message.class); + r.removeObjectReader(Tags.MESSAGE); + r.removeConsumer(digesting); + r.removeConsumer(counting); // Build and return the batch BatchId id = new BatchId(messageDigest.digest()); return batchFactory.createBatch(id, messages); diff --git a/components/net/sf/briar/protocol/BundleReaderImpl.java b/components/net/sf/briar/protocol/BundleReaderImpl.java index e82bdb6282f63ebc99328cea47fe83f4d145aa74..684a0b97bd91843cb9d65b70613444efc5939879 100644 --- a/components/net/sf/briar/protocol/BundleReaderImpl.java +++ b/components/net/sf/briar/protocol/BundleReaderImpl.java @@ -29,9 +29,8 @@ class BundleReaderImpl implements BundleReader { public Header getHeader() throws IOException, GeneralSecurityException { if(state != State.START) throw new IllegalStateException(); - reader.readUserDefinedTag(Tags.HEADER); reader.addObjectReader(Tags.HEADER, headerReader); - Header h = reader.readUserDefinedObject(Tags.HEADER, Header.class); + Header h = reader.readUserDefined(Tags.HEADER, Header.class); reader.removeObjectReader(Tags.HEADER); // Expect a list of batches reader.readListStart(); @@ -50,8 +49,7 @@ class BundleReaderImpl implements BundleReader { state = State.END; return null; } - reader.readUserDefinedTag(Tags.BATCH); - return reader.readUserDefinedObject(Tags.BATCH, Batch.class); + return reader.readUserDefined(Tags.BATCH, Batch.class); } public void finish() throws IOException { diff --git a/components/net/sf/briar/protocol/GroupIdReader.java b/components/net/sf/briar/protocol/GroupIdReader.java index bc1b1ab1be9ee9a1d395e1e7cae74b5e802f9ff8..17b58d14162e60da59906d120c36497e4e079225 100644 --- a/components/net/sf/briar/protocol/GroupIdReader.java +++ b/components/net/sf/briar/protocol/GroupIdReader.java @@ -3,6 +3,7 @@ package net.sf.briar.protocol; import java.io.IOException; import net.sf.briar.api.protocol.GroupId; +import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.ObjectReader; @@ -11,6 +12,7 @@ import net.sf.briar.api.serial.Reader; public class GroupIdReader implements ObjectReader<GroupId> { public GroupId readObject(Reader r) throws IOException { + r.readUserDefinedTag(Tags.GROUP_ID); byte[] b = r.readRaw(); if(b.length != UniqueId.LENGTH) throw new FormatException(); return new GroupId(b); diff --git a/components/net/sf/briar/protocol/HeaderReader.java b/components/net/sf/briar/protocol/HeaderReader.java index e42db8d03f562e2ceb55796604a587dc00ad284e..6c75d9d73cfe8dc311abc015e38b08c56fb0b96c 100644 --- a/components/net/sf/briar/protocol/HeaderReader.java +++ b/components/net/sf/briar/protocol/HeaderReader.java @@ -1,7 +1,6 @@ package net.sf.briar.protocol; import java.io.IOException; -import java.security.GeneralSecurityException; import java.util.Collection; import java.util.Map; @@ -21,28 +20,27 @@ class HeaderReader implements ObjectReader<Header> { this.headerFactory = headerFactory; } - public Header readObject(Reader reader) throws IOException, - GeneralSecurityException { - // Initialise and add the consumer - the initial tag has already been - // read, so subtract one from the maximum size - CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE - 1); - reader.addConsumer(counting); + public Header readObject(Reader r) throws IOException { + // Initialise and add the consumer + CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE); + r.addConsumer(counting); + r.readUserDefinedTag(Tags.HEADER); // Acks - reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader()); - Collection<BatchId> acks = reader.readList(BatchId.class); - reader.removeObjectReader(Tags.BATCH_ID); + r.addObjectReader(Tags.BATCH_ID, new BatchIdReader()); + Collection<BatchId> acks = r.readList(BatchId.class); + r.removeObjectReader(Tags.BATCH_ID); // Subs - reader.addObjectReader(Tags.GROUP_ID, new GroupIdReader()); - Collection<GroupId> subs = reader.readList(GroupId.class); - reader.removeObjectReader(Tags.GROUP_ID); + r.addObjectReader(Tags.GROUP_ID, new GroupIdReader()); + Collection<GroupId> subs = r.readList(GroupId.class); + r.removeObjectReader(Tags.GROUP_ID); // Transports Map<String, String> transports = - reader.readMap(String.class, String.class); + r.readMap(String.class, String.class); // Timestamp - long timestamp = reader.readInt64(); + long timestamp = r.readInt64(); if(timestamp < 0L) throw new FormatException(); // Remove the consumer - reader.removeConsumer(counting); + r.removeConsumer(counting); // Build and return the header return headerFactory.createHeader(acks, subs, transports, timestamp); } diff --git a/components/net/sf/briar/protocol/MessageReader.java b/components/net/sf/briar/protocol/MessageReader.java index ba4bad5394f94d07181dce7f37910806b6398c98..3258ae9568a82dd921c08d52fda501b920c565e6 100644 --- a/components/net/sf/briar/protocol/MessageReader.java +++ b/components/net/sf/briar/protocol/MessageReader.java @@ -32,40 +32,42 @@ class MessageReader implements ObjectReader<Message> { this.messageDigest = messageDigest; } - public Message readObject(Reader reader) throws IOException { + public Message readObject(Reader r) throws IOException { CopyingConsumer copying = new CopyingConsumer(); CountingConsumer counting = new CountingConsumer(Message.MAX_SIZE); - reader.addConsumer(copying); - reader.addConsumer(counting); + r.addConsumer(copying); + r.addConsumer(counting); + // Read the initial tag + r.readUserDefinedTag(Tags.MESSAGE); // Read the parent's message ID - reader.readUserDefinedTag(Tags.MESSAGE_ID); - byte[] b = reader.readRaw(); + r.readUserDefinedTag(Tags.MESSAGE_ID); + byte[] b = r.readRaw(); if(b.length != UniqueId.LENGTH) throw new FormatException(); MessageId parent = new MessageId(b); // Read the group ID - reader.readUserDefinedTag(Tags.GROUP_ID); - b = reader.readRaw(); + r.readUserDefinedTag(Tags.GROUP_ID); + b = r.readRaw(); if(b.length != UniqueId.LENGTH) throw new FormatException(); GroupId group = new GroupId(b); // Read the timestamp - long timestamp = reader.readInt64(); + long timestamp = r.readInt64(); if(timestamp < 0L) throw new FormatException(); // Hash the author's nick and public key to get the author ID DigestingConsumer digesting = new DigestingConsumer(messageDigest); messageDigest.reset(); - reader.addConsumer(digesting); - reader.readString(); - byte[] encodedKey = reader.readRaw(); - reader.removeConsumer(digesting); + r.addConsumer(digesting); + r.readString(); + byte[] encodedKey = r.readRaw(); + r.removeConsumer(digesting); AuthorId author = new AuthorId(messageDigest.digest()); // Skip the message body - reader.readRaw(); + r.readRaw(); // Record the length of the signed data int messageLength = (int) counting.getCount(); // Read the signature - byte[] sig = reader.readRaw(); - reader.removeConsumer(counting); - reader.removeConsumer(copying); + byte[] sig = r.readRaw(); + r.removeConsumer(counting); + r.removeConsumer(copying); // Verify the signature PublicKey publicKey; try { diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java index be019665f92a850ef22de6218b853ec64998d840..648e94a10ca9aeceede0bddddfa44aa228d9d1f7 100644 --- a/components/net/sf/briar/serial/ReaderImpl.java +++ b/components/net/sf/briar/serial/ReaderImpl.java @@ -2,7 +2,6 @@ package net.sf.briar.serial; import java.io.IOException; import java.io.InputStream; -import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -20,12 +19,11 @@ class ReaderImpl implements Reader { private static final byte[] EMPTY_BUFFER = new byte[] {}; private final InputStream in; - private final Map<Integer, ObjectReader<?>> objectReaders = - new HashMap<Integer, ObjectReader<?>>(); private Consumer[] consumers = new Consumer[] {}; + private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {}; private boolean started = false, eof = false; - private byte next; + private byte next, nextNext; private byte[] buf = null; ReaderImpl(InputStream in) { @@ -39,16 +37,32 @@ class ReaderImpl implements Reader { private byte readNext(boolean eofAcceptable) throws IOException { assert !eof; - if(started) for(Consumer c : consumers) c.write(next); + if(started) { + for(Consumer c : consumers) { + c.write(next); + if(next == Tag.USER) c.write(nextNext); + } + } started = true; + readLookahead(eofAcceptable); + return next; + } + + private void readLookahead(boolean eofAcceptable) throws IOException { + assert started; + // Read the lookahead byte int i = in.read(); if(i == -1) { - eof = true; if(!eofAcceptable) throw new FormatException(); + eof = true; } - if(i > 127) i -= 256; next = (byte) i; - return next; + // If necessary, read another lookahead byte + if(next == Tag.USER) { + i = in.read(); + if(i == -1) throw new FormatException(); + nextNext = (byte) i; + } } public void close() throws IOException { @@ -78,11 +92,20 @@ class ReaderImpl implements Reader { } public void addObjectReader(int tag, ObjectReader<?> o) { - objectReaders.put(tag, o); + if(tag < 0 || tag > 255) throw new IllegalArgumentException(); + if(objectReaders.length < tag + 1) { + ObjectReader<?>[] newObjectReaders = new ObjectReader<?>[tag + 1]; + System.arraycopy(objectReaders, 0, newObjectReaders, 0, + objectReaders.length); + objectReaders = newObjectReaders; + } + objectReaders[tag] = o; } public void removeObjectReader(int tag) { - objectReaders.remove(tag); + if(tag < 0 || tag > objectReaders.length) + throw new IllegalArgumentException(); + objectReaders[tag] = null; } public boolean hasBoolean() throws IOException { @@ -165,19 +188,22 @@ class ReaderImpl implements Reader { private void readIntoBuffer(byte[] b, int length) throws IOException { b[0] = next; int offset = 1; + if(next == Tag.USER) { + b[1] = nextNext; + offset = 2; + } while(offset < length) { int read = in.read(b, offset, length - offset); - if(read == -1) break; + if(read == -1) { + eof = true; + break; + } offset += read; } if(offset < length) throw new FormatException(); // Feed the hungry mouths for(Consumer c : consumers) c.write(b, 0, length); - // Read the lookahead byte - int i = in.read(); - if(i == -1) eof = true; - if(i > 127) i -= 256; - next = (byte) i; + readLookahead(true); } public boolean hasInt64() throws IOException { @@ -316,13 +342,11 @@ class ReaderImpl implements Reader { || (next & Tag.SHORT_MASK) == Tag.SHORT_LIST; } - public List<Object> readList() throws IOException, - GeneralSecurityException { + public List<Object> readList() throws IOException { return readList(Object.class); } - public <E> List<E> readList(Class<E> e) throws IOException, - GeneralSecurityException { + public <E> List<E> readList(Class<E> e) throws IOException { if(!hasList()) throw new FormatException(); if(next == Tag.LIST) { readNext(false); @@ -340,8 +364,7 @@ class ReaderImpl implements Reader { } } - private <E> List<E> readList(Class<E> e, int length) throws IOException, - GeneralSecurityException { + private <E> List<E> readList(Class<E> e, int length) throws IOException { assert length >= 0; List<E> list = new ArrayList<E>(); for(int i = 0; i < length; i++) list.add(readObject(e)); @@ -360,13 +383,9 @@ class ReaderImpl implements Reader { readNext(true); } - private Object readObject() throws IOException, GeneralSecurityException { + private Object readObject() throws IOException { if(!started) throw new IllegalStateException(); - if(hasUserDefinedTag()) { - ObjectReader<?> o = objectReaders.get(readUserDefinedTag()); - if(o == null) throw new FormatException(); - return o.readObject(this); - } + if(hasUserDefined()) return readUserDefined(); if(hasBoolean()) return Boolean.valueOf(readBoolean()); if(hasUint7()) return Byte.valueOf(readUint7()); if(hasInt8()) return Byte.valueOf(readInt8()); @@ -386,8 +405,23 @@ class ReaderImpl implements Reader { throw new FormatException(); } - private <T> T readObject(Class<T> t) throws IOException, - GeneralSecurityException { + private boolean hasUserDefined() throws IOException { + if(!started) readNext(true); + if(eof) return false; + if(next == Tag.USER) return true; + if((next & Tag.SHORT_USER_MASK) == Tag.SHORT_USER) return true; + return false; + } + + private Object readUserDefined() throws IOException { + if(!hasUserDefined()) throw new FormatException(); + int tag; + if(next == Tag.USER) tag = 0xFF & nextNext; + else tag = 0xFF & next ^ Tag.SHORT_USER; + return readUserDefined(tag, Object.class); + } + + private <T> T readObject(Class<T> t) throws IOException { try { return t.cast(readObject()); } catch(ClassCastException e) { @@ -421,13 +455,11 @@ class ReaderImpl implements Reader { || (next & Tag.SHORT_MASK) == Tag.SHORT_MAP; } - public Map<Object, Object> readMap() throws IOException, - GeneralSecurityException { + public Map<Object, Object> readMap() throws IOException { return readMap(Object.class, Object.class); } - public <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException, - GeneralSecurityException { + public <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException { if(!hasMap()) throw new FormatException(); if(next == Tag.MAP) { readNext(false); @@ -446,7 +478,7 @@ class ReaderImpl implements Reader { } private <K, V> Map<K, V> readMap(Class<K> k, Class<V> v, int size) - throws IOException, GeneralSecurityException { + throws IOException { assert size >= 0; Map<K, V> m = new HashMap<K, V>(); for(int i = 0; i < size; i++) m.put(readObject(k), readObject(v)); @@ -483,32 +515,21 @@ class ReaderImpl implements Reader { readNext(true); } - public boolean hasUserDefinedTag() throws IOException { + public boolean hasUserDefined(int tag) throws IOException { + if(tag < 0 || tag > 255) throw new IllegalArgumentException(); if(!started) readNext(true); if(eof) return false; - return next == Tag.USER || - (next & Tag.SHORT_USER_MASK) == Tag.SHORT_USER; - } - - public int readUserDefinedTag() throws IOException { - if(!hasUserDefinedTag()) throw new FormatException(); - if(next == Tag.USER) { - readNext(false); - return readLength(); - } else { - int tag = 0xFF & next ^ Tag.SHORT_USER; - readNext(true); - return tag; - } - } - - public void readUserDefinedTag(int tag) throws IOException { - if(readUserDefinedTag() != tag) throw new FormatException(); + if(next == Tag.USER) + return tag == (0xFF & nextNext); + else if((next & Tag.SHORT_USER_MASK) == Tag.SHORT_USER) + return tag == (0xFF & next ^ Tag.SHORT_USER); + else return false; } - public <T> T readUserDefinedObject(int tag, Class<T> t) throws IOException, - GeneralSecurityException { - ObjectReader<?> o = objectReaders.get(tag); + public <T> T readUserDefined(int tag, Class<T> t) throws IOException { + if(!hasUserDefined(tag)) throw new FormatException(); + if(tag >= objectReaders.length) throw new FormatException(); + ObjectReader<?> o = objectReaders[tag]; if(o == null) throw new FormatException(); try { return t.cast(o.readObject(this)); @@ -516,4 +537,9 @@ class ReaderImpl implements Reader { throw new FormatException(); } } + + public void readUserDefinedTag(int tag) throws IOException { + if(!hasUserDefined(tag)) throw new FormatException(); + readNext(false); + } } diff --git a/components/net/sf/briar/serial/WriterImpl.java b/components/net/sf/briar/serial/WriterImpl.java index 9f7360075a84b1a9902cdcc1c1e0cf17bc0da506..bb27a56aaaede2c4f948ada5a32b41b892835fd0 100644 --- a/components/net/sf/briar/serial/WriterImpl.java +++ b/components/net/sf/briar/serial/WriterImpl.java @@ -111,7 +111,7 @@ class WriterImpl implements Writer { public void writeString(String s) throws IOException { byte[] b = s.getBytes("UTF-8"); - if(b.length < 16) out.write(intToByte(Tag.SHORT_STRING | b.length)); + if(b.length < 16) out.write((byte) (Tag.SHORT_STRING | b.length)); else { out.write(Tag.STRING); writeLength(b.length); @@ -120,12 +120,6 @@ class WriterImpl implements Writer { bytesWritten += b.length + 1; } - private byte intToByte(int i) { - assert i >= 0; - assert i <= 255; - return (byte) (i > 127 ? i - 256 : i); - } - private void writeLength(int i) throws IOException { assert i >= 0; // Fun fact: it's never worth writing a length as an int8 @@ -135,7 +129,7 @@ class WriterImpl implements Writer { } public void writeRaw(byte[] b) throws IOException { - if(b.length < 16) out.write(intToByte(Tag.SHORT_RAW | b.length)); + if(b.length < 16) out.write((byte) (Tag.SHORT_RAW | b.length)); else { out.write(Tag.RAW); writeLength(b.length); @@ -150,7 +144,7 @@ class WriterImpl implements Writer { public void writeList(Collection<?> c) throws IOException { int length = c.size(); - if(length < 16) out.write(intToByte(Tag.SHORT_LIST | length)); + if(length < 16) out.write((byte) (Tag.SHORT_LIST | length)); else { out.write(Tag.LIST); writeLength(length); @@ -188,7 +182,7 @@ class WriterImpl implements Writer { public void writeMap(Map<?, ?> m) throws IOException { int length = m.size(); - if(length < 16) out.write(intToByte(Tag.SHORT_MAP | length)); + if(length < 16) out.write((byte) (Tag.SHORT_MAP | length)); else { out.write(Tag.MAP); writeLength(length); @@ -216,12 +210,14 @@ class WriterImpl implements Writer { } public void writeUserDefinedTag(int tag) throws IOException { - if(tag < 0) throw new IllegalArgumentException(); - if(tag < 32) out.write((byte) (Tag.SHORT_USER | tag)); - else { + if(tag < 0 || tag > 255) throw new IllegalArgumentException(); + if(tag < 32) { + out.write((byte) (Tag.SHORT_USER | tag)); + bytesWritten++; + } else { out.write(Tag.USER); - writeLength(tag); + out.write((byte) tag); + bytesWritten += 2; } - bytesWritten++; } } diff --git a/test/net/sf/briar/protocol/BatchReaderTest.java b/test/net/sf/briar/protocol/BatchReaderTest.java index 9c93620c65904e4ceffa0af6c5be8d4fb9c71dc7..6eaa425be7698385c3e823b7b0c24b5833cc234e 100644 --- a/test/net/sf/briar/protocol/BatchReaderTest.java +++ b/test/net/sf/briar/protocol/BatchReaderTest.java @@ -3,7 +3,6 @@ package net.sf.briar.protocol; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.security.GeneralSecurityException; import java.security.MessageDigest; import java.util.Collections; @@ -65,9 +64,8 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.BATCH, batchReader); - reader.readUserDefinedTag(Tags.BATCH); try { - reader.readUserDefinedObject(Tags.BATCH, Batch.class); + reader.readUserDefined(Tags.BATCH, Batch.class); assertTrue(false); } catch(FormatException expected) {} context.assertIsSatisfied(); @@ -91,17 +89,15 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.BATCH, batchReader); - reader.readUserDefinedTag(Tags.BATCH); - assertEquals(batch, reader.readUserDefinedObject(Tags.BATCH, - Batch.class)); + assertEquals(batch, reader.readUserDefined(Tags.BATCH, Batch.class)); context.assertIsSatisfied(); } @Test public void testBatchId() throws Exception { byte[] b = createBatch(Batch.MAX_SIZE); - // Calculate the expected batch ID, skipping the initial tag - messageDigest.update(b, 1, b.length - 1); + // Calculate the expected batch ID + messageDigest.update(b); final BatchId id = new BatchId(messageDigest.digest()); messageDigest.reset(); @@ -121,9 +117,7 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.BATCH, batchReader); - reader.readUserDefinedTag(Tags.BATCH); - assertEquals(batch, reader.readUserDefinedObject(Tags.BATCH, - Batch.class)); + assertEquals(batch, reader.readUserDefined(Tags.BATCH, Batch.class)); context.assertIsSatisfied(); } @@ -145,9 +139,7 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.BATCH, batchReader); - reader.readUserDefinedTag(Tags.BATCH); - assertEquals(batch, reader.readUserDefinedObject(Tags.BATCH, - Batch.class)); + assertEquals(batch, reader.readUserDefined(Tags.BATCH, Batch.class)); context.assertIsSatisfied(); } @@ -156,8 +148,8 @@ public class BatchReaderTest extends TestCase { Writer w = writerFactory.createWriter(out); w.writeUserDefinedTag(Tags.BATCH); w.writeListStart(); - w.writeUserDefinedTag(Tags.MESSAGE); // We're using a fake message reader, so it's OK to use a fake message + w.writeUserDefinedTag(Tags.MESSAGE); w.writeRaw(new byte[size - 10]); w.writeListEnd(); w.close(); @@ -178,8 +170,8 @@ public class BatchReaderTest extends TestCase { private class TestMessageReader implements ObjectReader<Message> { - public Message readObject(Reader r) throws IOException, - GeneralSecurityException { + public Message readObject(Reader r) throws IOException { + r.readUserDefinedTag(Tags.MESSAGE); r.readRaw(); return message; } diff --git a/test/net/sf/briar/protocol/BundleReaderImplTest.java b/test/net/sf/briar/protocol/BundleReaderImplTest.java index 7506fbfdd45169e5793a2137542c35c7118a8853..90a358bd9e96a36a44057501378333d6835313e6 100644 --- a/test/net/sf/briar/protocol/BundleReaderImplTest.java +++ b/test/net/sf/briar/protocol/BundleReaderImplTest.java @@ -3,7 +3,6 @@ package net.sf.briar.protocol; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.security.GeneralSecurityException; import java.util.Collections; import junit.framework.TestCase; @@ -217,8 +216,7 @@ public class BundleReaderImplTest extends TestCase { private class TestHeaderReader implements ObjectReader<Header> { - public Header readObject(Reader r) throws IOException, - GeneralSecurityException { + public Header readObject(Reader r) throws IOException { r.readList(); r.readList(); r.readMap(); @@ -229,8 +227,7 @@ public class BundleReaderImplTest extends TestCase { private class TestBatchReader implements ObjectReader<Batch> { - public Batch readObject(Reader r) throws IOException, - GeneralSecurityException { + public Batch readObject(Reader r) throws IOException { r.readList(); return context.mock(Batch.class); } diff --git a/test/net/sf/briar/protocol/HeaderReaderTest.java b/test/net/sf/briar/protocol/HeaderReaderTest.java index 329e422449ceb0358154fd394178367e74fbc3e5..af2648e9bfb69da85da1c9cca680fcf478c62e97 100644 --- a/test/net/sf/briar/protocol/HeaderReaderTest.java +++ b/test/net/sf/briar/protocol/HeaderReaderTest.java @@ -51,9 +51,8 @@ public class HeaderReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.HEADER, headerReader); - reader.readUserDefinedTag(Tags.HEADER); try { - reader.readUserDefinedObject(Tags.HEADER, Header.class); + reader.readUserDefined(Tags.HEADER, Header.class); assertTrue(false); } catch(FormatException expected) {} context.assertIsSatisfied(); @@ -78,9 +77,7 @@ public class HeaderReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.HEADER, headerReader); - reader.readUserDefinedTag(Tags.HEADER); - assertEquals(header, reader.readUserDefinedObject(Tags.HEADER, - Header.class)); + assertEquals(header, reader.readUserDefined(Tags.HEADER, Header.class)); context.assertIsSatisfied(); } @@ -103,26 +100,23 @@ public class HeaderReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Tags.HEADER, headerReader); - reader.readUserDefinedTag(Tags.HEADER); - assertEquals(header, reader.readUserDefinedObject(Tags.HEADER, - Header.class)); + assertEquals(header, reader.readUserDefined(Tags.HEADER, Header.class)); context.assertIsSatisfied(); } private byte[] createHeader(int size) throws Exception { - Random random = new Random(); ByteArrayOutputStream out = new ByteArrayOutputStream(size); Writer w = writerFactory.createWriter(out); w.writeUserDefinedTag(Tags.HEADER); - // Acks + // No acks w.writeListStart(); w.writeListEnd(); - // Subs + // Fill most of the header with subs w.writeListStart(); - // Fill most of the header with subscriptions - while(w.getBytesWritten() < size - 45) { + byte[] b = new byte[UniqueId.LENGTH]; + Random random = new Random(); + while(out.size() < size - 60) { w.writeUserDefinedTag(Tags.GROUP_ID); - byte[] b = new byte[UniqueId.LENGTH]; random.nextBytes(b); w.writeRaw(b); } @@ -131,7 +125,8 @@ public class HeaderReaderTest extends TestCase { w.writeMapStart(); w.writeString("foo"); // Build a string that will bring the header up to the expected size - int length = (int) (size - w.getBytesWritten() - 12); + int length = size - out.size() - 12; + assertTrue(length > 0); StringBuilder s = new StringBuilder(); for(int i = 0; i < length; i++) s.append((char) ('0' + i % 10)); w.writeString(s.toString()); @@ -139,9 +134,8 @@ public class HeaderReaderTest extends TestCase { // Timestamp w.writeInt64(System.currentTimeMillis()); w.close(); - byte[] b = out.toByteArray(); - assertEquals(size, b.length); - return b; + assertEquals(size, out.size()); + return out.toByteArray(); } private byte[] createEmptyHeader() throws Exception { diff --git a/test/net/sf/briar/serial/ReaderImplTest.java b/test/net/sf/briar/serial/ReaderImplTest.java index 33f24720867ef2615bdc7431c9a3492ddb72e1b0..5a52f28d85594bc0628dd6a473462897c9043512 100644 --- a/test/net/sf/briar/serial/ReaderImplTest.java +++ b/test/net/sf/briar/serial/ReaderImplTest.java @@ -1,6 +1,7 @@ package net.sf.briar.serial; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -8,6 +9,7 @@ import java.util.Map; import java.util.Map.Entry; import junit.framework.TestCase; +import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Raw; @@ -331,35 +333,77 @@ public class ReaderImplTest extends TestCase { } @Test - public void testReadUserDefinedTag() throws Exception { - setContents("C0" + "DF" + "EF" + "20" + "EF" + "FB7FFFFFFF"); - assertEquals(0, r.readUserDefinedTag()); - assertEquals(31, r.readUserDefinedTag()); - assertEquals(32, r.readUserDefinedTag()); - assertEquals(Integer.MAX_VALUE, r.readUserDefinedTag()); - assertTrue(r.eof()); + public void testReadUserDefined() throws Exception { + setContents("C0" + "83666F6F" + "EF" + "FF" + "83666F6F"); + // Add object readers for two user-defined types + r.addObjectReader(0, new ObjectReader<Foo>() { + public Foo readObject(Reader r) throws IOException { + r.readUserDefinedTag(0); + return new Foo(r.readString()); + } + }); + r.addObjectReader(255, new ObjectReader<Bar>() { + public Bar readObject(Reader r) throws IOException { + r.readUserDefinedTag(255); + return new Bar(r.readString()); + } + }); + // Test both tag formats, short and long + assertTrue(r.hasUserDefined(0)); + assertEquals("foo", r.readUserDefined(0, Foo.class).s); + assertTrue(r.hasUserDefined(255)); + assertEquals("foo", r.readUserDefined(255, Bar.class).s); } @Test - public void testReadUserDefinedObject() throws Exception { - setContents("C0" + "83666F6F"); - // Add an object reader for a user-defined type + public void testReadUserDefinedWithConsumer() throws Exception { + setContents("C0" + "83666F6F" + "EF" + "FF" + "83666F6F"); + // Add object readers for two user-defined types r.addObjectReader(0, new ObjectReader<Foo>() { public Foo readObject(Reader r) throws IOException { + r.readUserDefinedTag(0); return new Foo(r.readString()); } }); - assertEquals(0, r.readUserDefinedTag()); - assertEquals("foo", r.readUserDefinedObject(0, Foo.class).s); + r.addObjectReader(255, new ObjectReader<Bar>() { + public Bar readObject(Reader r) throws IOException { + r.readUserDefinedTag(255); + return new Bar(r.readString()); + } + }); + // Add a consumer + final ByteArrayOutputStream out = new ByteArrayOutputStream(); + r.addConsumer(new Consumer() { + + public void write(byte b) throws IOException { + out.write(b); + } + + public void write(byte[] b) throws IOException { + out.write(b); + } + + public void write(byte[] b, int off, int len) throws IOException { + out.write(b, off, len); + } + }); + // Test both tag formats, short and long + assertTrue(r.hasUserDefined(0)); + assertEquals("foo", r.readUserDefined(0, Foo.class).s); + assertTrue(r.hasUserDefined(255)); + assertEquals("foo", r.readUserDefined(255, Bar.class).s); + // Check that everything was passed to the consumer + assertEquals("C0" + "83666F6F" + "EF" + "FF" + "83666F6F", + StringUtils.toHexString(out.toByteArray())); } @Test public void testUnknownTagThrowsFormatException() throws Exception { setContents("C0" + "83666F6F"); + assertTrue(r.hasUserDefined(0)); // No object reader has been added for tag 0 - assertEquals(0, r.readUserDefinedTag()); try { - r.readUserDefinedObject(0, Foo.class); + r.readUserDefined(0, Foo.class); assertTrue(false); } catch(FormatException expected) {} } @@ -370,13 +414,14 @@ public class ReaderImplTest extends TestCase { // Add an object reader for tag 0, class Foo r.addObjectReader(0, new ObjectReader<Foo>() { public Foo readObject(Reader r) throws IOException { + r.readUserDefinedTag(0); return new Foo(r.readString()); } }); - assertEquals(0, r.readUserDefinedTag()); + assertTrue(r.hasUserDefined(0)); // Trying to read the object as class Bar should throw a FormatException try { - r.readUserDefinedObject(0, Bar.class); + r.readUserDefined(0, Bar.class); assertTrue(false); } catch(FormatException expected) {} } @@ -387,6 +432,7 @@ public class ReaderImplTest extends TestCase { // Add an object reader for a user-defined type r.addObjectReader(0, new ObjectReader<Foo>() { public Foo readObject(Reader r) throws IOException { + r.readUserDefinedTag(0); return new Foo(r.readString()); } }); @@ -402,11 +448,13 @@ public class ReaderImplTest extends TestCase { // Add object readers for two user-defined types r.addObjectReader(0, new ObjectReader<Foo>() { public Foo readObject(Reader r) throws IOException { + r.readUserDefinedTag(0); return new Foo(r.readString()); } }); r.addObjectReader(1, new ObjectReader<Bar>() { public Bar readObject(Reader r) throws IOException { + r.readUserDefinedTag(1); return new Bar(r.readString()); } }); diff --git a/test/net/sf/briar/serial/WriterImplTest.java b/test/net/sf/briar/serial/WriterImplTest.java index 6bc4b7c2f58e2c010c28032745073ab6a0da3e2b..f71eb9edec08ac54c44ef263e345799ba2ac6f22 100644 --- a/test/net/sf/briar/serial/WriterImplTest.java +++ b/test/net/sf/briar/serial/WriterImplTest.java @@ -298,9 +298,9 @@ public class WriterImplTest extends TestCase { @Test public void testWriteUserDefinedTag() throws IOException { w.writeUserDefinedTag(32); - w.writeUserDefinedTag(Integer.MAX_VALUE); - // USER tag, 32 as uint7, USER tag, 2147483647 as int32 - checkContents("EF" + "20" + "EF" + "FB7FFFFFFF"); + w.writeUserDefinedTag(255); + // USER tag, 32 as uint8, USER tag, 255 as uint8 + checkContents("EF" + "20" + "EF" + "FF"); } @Test @@ -323,6 +323,5 @@ public class WriterImplTest extends TestCase { byte[] expected = StringUtils.fromHexString(hex); assertTrue(StringUtils.toHexString(out.toByteArray()), Arrays.equals(expected, out.toByteArray())); - assertEquals(expected.length, w.getBytesWritten()); } } diff --git a/util/net/sf/briar/util/StringUtils.java b/util/net/sf/briar/util/StringUtils.java index e99bf14879d2f90721d5a7361fe3385b57a2298f..cc0c92c18861d5c720aac7563fdca23d801bd6c9 100644 --- a/util/net/sf/briar/util/StringUtils.java +++ b/util/net/sf/briar/util/StringUtils.java @@ -45,9 +45,7 @@ public class StringUtils { for(int i = 0, j = 0; i < len; i += 2, j++) { int high = hexDigitToInt(hex.charAt(i)); int low = hexDigitToInt(hex.charAt(i + 1)); - int b = (high << 4) + low; - if(b > 127) b -= 256; - bytes[j] = (byte) b; + bytes[j] = (byte) ((high << 4) + low); } return bytes; }