diff --git a/components/net/sf/briar/protocol/DigestingConsumer.java b/components/net/sf/briar/protocol/DigestingConsumer.java index ee9eaed41fe500b36e34750c0f814de5d7d7dbc0..78826ff45be8aee78b0ccfed14c07a129b541294 100644 --- a/components/net/sf/briar/protocol/DigestingConsumer.java +++ b/components/net/sf/briar/protocol/DigestingConsumer.java @@ -1,6 +1,5 @@ package net.sf.briar.protocol; -import java.io.IOException; import java.security.MessageDigest; import net.sf.briar.api.serial.Consumer; @@ -14,15 +13,15 @@ class DigestingConsumer implements Consumer { this.messageDigest = messageDigest; } - public void write(byte b) throws IOException { + public void write(byte b) { messageDigest.update(b); } - public void write(byte[] b) throws IOException { + public void write(byte[] b) { messageDigest.update(b); } - public void write(byte[] b, int off, int len) throws IOException { + public void write(byte[] b, int off, int len) { messageDigest.update(b, off, len); } } diff --git a/components/net/sf/briar/transport/MacConsumer.java b/components/net/sf/briar/transport/MacConsumer.java new file mode 100644 index 0000000000000000000000000000000000000000..dbf20af8d12b458e1116511cc5001ee8e1a97322 --- /dev/null +++ b/components/net/sf/briar/transport/MacConsumer.java @@ -0,0 +1,27 @@ +package net.sf.briar.transport; + +import javax.crypto.Mac; + +import net.sf.briar.api.serial.Consumer; + +/** A consumer that passes its input through a MAC. */ +class MacConsumer implements Consumer { + + private final Mac mac; + + MacConsumer(Mac mac) { + this.mac = mac; + } + + public void write(byte b) { + mac.update(b); + } + + public void write(byte[] b) { + mac.update(b); + } + + public void write(byte[] b, int off, int len) { + mac.update(b, off, len); + } +} diff --git a/components/net/sf/briar/transport/PacketReaderImpl.java b/components/net/sf/briar/transport/PacketReaderImpl.java index 3ad2eb0f11e6c1f3e101a04ddc8b589aec8dfda1..7b654f1adc536c20c30bd64f0127d1ada9911ae8 100644 --- a/components/net/sf/briar/transport/PacketReaderImpl.java +++ b/components/net/sf/briar/transport/PacketReaderImpl.java @@ -2,6 +2,7 @@ package net.sf.briar.transport; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; import javax.crypto.Mac; @@ -13,6 +14,7 @@ import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.writers.ProtocolReaderFactory; +import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.transport.PacketReader; @@ -22,7 +24,7 @@ class PacketReaderImpl implements PacketReader { private final Reader reader; private final PacketDecrypter decrypter; private final Mac mac; - private final int transportId; + private final int macLength, transportId; private final long connection; private long packet = 0L; @@ -41,8 +43,10 @@ class PacketReaderImpl implements PacketReader { protocol.createSubscriptionReader(in)); reader.addObjectReader(Tags.TRANSPORTS, protocol.createTransportReader(in)); + reader.addConsumer(new MacConsumer(mac)); this.decrypter = decrypter; this.mac = mac; + macLength = mac.getMacLength(); this.transportId = transportId; this.connection = connection; } @@ -62,7 +66,7 @@ class PacketReaderImpl implements PacketReader { throw new IllegalStateException(); byte[] tag = decrypter.readTag(); if(!TagDecoder.decodeTag(tag, transportId, connection, packet)) - throw new IOException(); + throw new FormatException(); mac.update(tag); packet++; betweenPackets = false; @@ -70,7 +74,24 @@ class PacketReaderImpl implements PacketReader { public Ack readAck() throws IOException { if(betweenPackets) readTag(); - return reader.readUserDefined(Tags.ACK, Ack.class); + Ack a = reader.readUserDefined(Tags.ACK, Ack.class); + readMac(); + betweenPackets = true; + return a; + } + + private void readMac() throws IOException { + byte[] expectedMac = mac.doFinal(); + byte[] actualMac = new byte[macLength]; + InputStream in = decrypter.getInputStream(); + int offset = 0; + while(offset < macLength) { + int read = in.read(actualMac, offset, actualMac.length - offset); + if(read == -1) break; + offset += read; + } + if(offset < macLength) throw new FormatException(); + if(!Arrays.equals(expectedMac, actualMac)) throw new FormatException(); } public boolean hasBatch() throws IOException { @@ -80,7 +101,10 @@ class PacketReaderImpl implements PacketReader { public Batch readBatch() throws IOException { if(betweenPackets) readTag(); - return reader.readUserDefined(Tags.BATCH, Batch.class); + Batch b = reader.readUserDefined(Tags.BATCH, Batch.class); + readMac(); + betweenPackets = true; + return b; } public boolean hasOffer() throws IOException { @@ -90,7 +114,10 @@ class PacketReaderImpl implements PacketReader { public Offer readOffer() throws IOException { if(betweenPackets) readTag(); - return reader.readUserDefined(Tags.OFFER, Offer.class); + Offer o = reader.readUserDefined(Tags.OFFER, Offer.class); + readMac(); + betweenPackets = true; + return o; } public boolean hasRequest() throws IOException { @@ -100,7 +127,10 @@ class PacketReaderImpl implements PacketReader { public Request readRequest() throws IOException { if(betweenPackets) readTag(); - return reader.readUserDefined(Tags.REQUEST, Request.class); + Request r = reader.readUserDefined(Tags.REQUEST, Request.class); + readMac(); + betweenPackets = true; + return r; } public boolean hasSubscriptionUpdate() throws IOException { @@ -110,8 +140,11 @@ class PacketReaderImpl implements PacketReader { public SubscriptionUpdate readSubscriptionUpdate() throws IOException { if(betweenPackets) readTag(); - return reader.readUserDefined(Tags.SUBSCRIPTIONS, + SubscriptionUpdate s = reader.readUserDefined(Tags.SUBSCRIPTIONS, SubscriptionUpdate.class); + readMac(); + betweenPackets = true; + return s; } public boolean hasTransportUpdate() throws IOException { @@ -121,6 +154,10 @@ class PacketReaderImpl implements PacketReader { public TransportUpdate readTransportUpdate() throws IOException { if(betweenPackets) readTag(); - return reader.readUserDefined(Tags.TRANSPORTS, TransportUpdate.class); + TransportUpdate t = reader.readUserDefined(Tags.TRANSPORTS, + TransportUpdate.class); + readMac(); + betweenPackets = true; + return t; } }