diff --git a/api/net/sf/briar/api/protocol/Message.java b/api/net/sf/briar/api/protocol/Message.java index 0f145f01e53debb275ca2d174f6675ca5ca6d52d..7a3506dc71f2aba3a2a6060eb976d9a50a99a856 100644 --- a/api/net/sf/briar/api/protocol/Message.java +++ b/api/net/sf/briar/api/protocol/Message.java @@ -1,7 +1,5 @@ package net.sf.briar.api.protocol; -import java.io.InputStream; - public interface Message extends MessageHeader { /** @@ -21,15 +19,15 @@ public interface Message extends MessageHeader { /** The length of the random salt in bytes. */ static final int SALT_LENGTH = 8; - /** Returns the length of the message in bytes. */ + /** Returns the length of the serialised message in bytes. */ int getLength(); - /** Returns the serialised representation of the entire message. */ - byte[] getSerialisedBytes(); + /** Returns the serialised message. */ + byte[] getSerialised(); - /** - * Returns a stream for reading the serialised representation of the entire - * message. - */ - InputStream getSerialisedStream(); + /** Returns the offset of the message body within the serialised message. */ + int getBodyStart(); + + /** Returns the length of the message body in bytes. */ + int getBodyLength(); } \ No newline at end of file diff --git a/components/net/sf/briar/db/Database.java b/components/net/sf/briar/db/Database.java index f5dd6e39ae874826c29277a7aea34024615134e2..9a7eaf6beafec274381dfdfeade182f1fc31a91e 100644 --- a/components/net/sf/briar/db/Database.java +++ b/components/net/sf/briar/db/Database.java @@ -239,12 +239,19 @@ interface Database<T> { Collection<BatchId> getLostBatches(T txn, ContactId c) throws DbException; /** - * Returns the message identified by the given ID, in raw format. + * Returns the message identified by the given ID, in serialised form. * <p> * Locking: messages read. */ byte[] getMessage(T txn, MessageId m) throws DbException; + /** + * Returns the body of the message identified by the given ID. + * <p> + * Locking: messages read. + */ + byte[] getMessageBody(T txn, MessageId m) throws DbException; + /** * Returns the message identified by the given ID, in raw format, or null * if the message is not present in the database or is not sendable to the diff --git a/components/net/sf/briar/db/DatabaseComponentImpl.java b/components/net/sf/briar/db/DatabaseComponentImpl.java index 356338b1d9081d3c62425c6144441f03f8a152a2..ea996eb7760b6138db68d5d71705b0852bd1d152 100644 --- a/components/net/sf/briar/db/DatabaseComponentImpl.java +++ b/components/net/sf/briar/db/DatabaseComponentImpl.java @@ -432,8 +432,7 @@ DatabaseCleaner.Callback { int capacity = b.getCapacity(); ids = db.getSendableMessages(txn, c, capacity); for(MessageId m : ids) { - byte[] raw = db.getMessage(txn, m); - messages.add(new Bytes(raw)); + messages.add(new Bytes(db.getMessage(txn, m))); } db.commitTransaction(txn); } catch(DbException e) { diff --git a/components/net/sf/briar/db/JdbcDatabase.java b/components/net/sf/briar/db/JdbcDatabase.java index c69b16d4ca81f388253b857725f2724ac5fb1873..6d7bed07fa7bad00519063bfb2252fd8578c5427 100644 --- a/components/net/sf/briar/db/JdbcDatabase.java +++ b/components/net/sf/briar/db/JdbcDatabase.java @@ -1,7 +1,9 @@ package net.sf.briar.db; +import java.io.EOFException; import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -61,8 +63,11 @@ abstract class JdbcDatabase implements Database<Connection> { + " parentId HASH," + " groupId HASH," + " authorId HASH," + + " subject VARCHAR NOT NULL," + " timestamp BIGINT NOT NULL," - + " size INT NOT NULL," + + " length INT NOT NULL," + + " bodyStart INT NOT NULL," + + " bodyLength INT NOT NULL," + " raw BLOB NOT NULL," + " sendability INT," + " contactId INT," @@ -536,10 +541,10 @@ abstract class JdbcDatabase implements Database<Connection> { if(containsMessage(txn, m.getId())) return false; PreparedStatement ps = null; try { - String sql = "INSERT INTO messages" - + " (messageId, parentId, groupId, authorId, timestamp, size," - + " raw, sendability)" - + " VALUES (?, ?, ?, ?, ?, ?, ?, ZERO())"; + String sql = "INSERT INTO messages (messageId, parentId, groupId," + + " authorId, subject, timestamp, length, bodyStart," + + " bodyLength, raw, sendability)" + + " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ZERO())"; ps = txn.prepareStatement(sql); ps.setBytes(1, m.getId().getBytes()); if(m.getParent() == null) ps.setNull(2, Types.BINARY); @@ -547,10 +552,12 @@ abstract class JdbcDatabase implements Database<Connection> { ps.setBytes(3, m.getGroup().getBytes()); if(m.getAuthor() == null) ps.setNull(4, Types.BINARY); else ps.setBytes(4, m.getAuthor().getBytes()); - ps.setLong(5, m.getTimestamp()); - int length = m.getLength(); - ps.setInt(6, length); - ps.setBinaryStream(7, m.getSerialisedStream(), length); + ps.setString(5, m.getSubject()); + ps.setLong(6, m.getTimestamp()); + ps.setInt(7, m.getLength()); + ps.setInt(8, m.getBodyStart()); + ps.setInt(9, m.getBodyLength()); + ps.setBytes(10, m.getSerialised()); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); @@ -626,18 +633,20 @@ abstract class JdbcDatabase implements Database<Connection> { if(containsMessage(txn, m.getId())) return false; PreparedStatement ps = null; try { - String sql = "INSERT INTO messages" - + " (messageId, parentId, timestamp, size, raw, contactId)" - + " VALUES (?, ?, ?, ?, ?, ?)"; + String sql = "INSERT INTO messages (messageId, parentId, subject," + + " timestamp, length, bodyStart, bodyLength, raw, contactId)" + + " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setBytes(1, m.getId().getBytes()); if(m.getParent() == null) ps.setNull(2, Types.BINARY); else ps.setBytes(2, m.getParent().getBytes()); - ps.setLong(3, m.getTimestamp()); - int length = m.getLength(); - ps.setInt(4, length); - ps.setBinaryStream(5, m.getSerialisedStream(), length); - ps.setInt(6, c.getInt()); + ps.setString(3, m.getSubject()); + ps.setLong(4, m.getTimestamp()); + ps.setInt(5, m.getLength()); + ps.setInt(6, m.getBodyStart()); + ps.setInt(7, m.getBodyLength()); + ps.setBytes(8, m.getSerialised()); + ps.setInt(9, c.getInt()); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); @@ -1042,14 +1051,14 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT size, raw FROM messages WHERE messageId = ?"; + String sql = "SELECT length, raw FROM messages WHERE messageId = ?"; ps = txn.prepareStatement(sql); ps.setBytes(1, m.getBytes()); rs = ps.executeQuery(); if(!rs.next()) throw new DbStateException(); - int size = rs.getInt(1); - byte[] raw = rs.getBlob(2).getBytes(1, size); - if(raw.length != size) throw new DbStateException(); + int length = rs.getInt(1); + byte[] raw = rs.getBlob(2).getBytes(1, length); + if(raw.length != length) throw new DbStateException(); if(rs.next()) throw new DbStateException(); rs.close(); ps.close(); @@ -1061,13 +1070,60 @@ abstract class JdbcDatabase implements Database<Connection> { } } + public byte[] getMessageBody(Connection txn, MessageId m) + throws DbException { + PreparedStatement ps = null; + ResultSet rs = null; + try { + String sql = "SELECT bodyStart, bodyLength, raw FROM messages" + + " WHERE messageId = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, m.getBytes()); + rs = ps.executeQuery(); + if(!rs.next()) throw new DbStateException(); + int bodyStart = rs.getInt(1); + int bodyLength = rs.getInt(2); + InputStream in = rs.getBlob(3).getBinaryStream(); + // FIXME: We have to read and discard the header because + // InputStream.skip() is broken for blobs - find out why + byte[] head = new byte[bodyStart]; + byte[] body = new byte[bodyLength]; + try { + int offset = 0; + while(offset < head.length) { + int read = in.read(head, offset, head.length - offset); + if(read == -1) throw new SQLException(new EOFException()); + offset += read; + } + offset = 0; + while(offset < body.length) { + int read = in.read(body, offset, body.length - offset); + if(read == -1) throw new SQLException(new EOFException()); + offset += read; + } + in.close(); + } catch(IOException e) { + if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); + throw new SQLException(e); + } + if(rs.next()) throw new DbStateException(); + rs.close(); + ps.close(); + return body; + } catch(SQLException e) { + tryToClose(rs); + tryToClose(ps); + throw new DbException(e); + } + } + public byte[] getMessageIfSendable(Connection txn, ContactId c, MessageId m) throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { // Do we have a sendable private message with the given ID? - String sql = "SELECT size, raw FROM messages" + String sql = "SELECT length, raw FROM messages" + " JOIN statuses ON messages.messageId = statuses.messageId" + " WHERE messages.messageId = ? AND messages.contactId = ?" + " AND status = ?"; @@ -1078,16 +1134,16 @@ abstract class JdbcDatabase implements Database<Connection> { rs = ps.executeQuery(); byte[] raw = null; if(rs.next()) { - int size = rs.getInt(1); - raw = rs.getBlob(2).getBytes(1, size); - if(raw.length != size) throw new DbStateException(); + int length = rs.getInt(1); + raw = rs.getBlob(2).getBytes(1, length); + if(raw.length != length) throw new DbStateException(); } if(rs.next()) throw new DbStateException(); rs.close(); ps.close(); if(raw != null) return raw; // Do we have a sendable group message with the given ID? - sql = "SELECT size, raw FROM messages" + sql = "SELECT length, raw FROM messages" + " JOIN contactSubscriptions" + " ON messages.groupId = contactSubscriptions.groupId" + " JOIN visibilities" @@ -1107,9 +1163,9 @@ abstract class JdbcDatabase implements Database<Connection> { ps.setShort(3, (short) Status.NEW.ordinal()); rs = ps.executeQuery(); if(rs.next()) { - int size = rs.getInt(1); - raw = rs.getBlob(2).getBytes(1, size); - if(raw.length != size) throw new DbStateException(); + int length = rs.getInt(1); + raw = rs.getBlob(2).getBytes(1, length); + if(raw.length != length) throw new DbStateException(); } if(rs.next()) throw new DbStateException(); rs.close(); @@ -1203,17 +1259,17 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT size, messageId FROM messages" + String sql = "SELECT length, messageId FROM messages" + " ORDER BY timestamp"; ps = txn.prepareStatement(sql); rs = ps.executeQuery(); Collection<MessageId> ids = new ArrayList<MessageId>(); int total = 0; while(rs.next()) { - int size = rs.getInt(1); - if(total + size > capacity) break; + int length = rs.getInt(1); + if(total + length > capacity) break; ids.add(new MessageId(rs.getBytes(2))); - total += size; + total += length; } rs.close(); ps.close(); @@ -1361,7 +1417,7 @@ abstract class JdbcDatabase implements Database<Connection> { ResultSet rs = null; try { // Do we have any sendable private messages? - String sql = "SELECT size, messages.messageId FROM messages" + String sql = "SELECT length, messages.messageId FROM messages" + " JOIN statuses ON messages.messageId = statuses.messageId" + " WHERE messages.contactId = ? AND status = ?" + " ORDER BY timestamp"; @@ -1372,10 +1428,10 @@ abstract class JdbcDatabase implements Database<Connection> { Collection<MessageId> ids = new ArrayList<MessageId>(); int total = 0; while(rs.next()) { - int size = rs.getInt(1); - if(total + size > capacity) break; + int length = rs.getInt(1); + if(total + length > capacity) break; ids.add(new MessageId(rs.getBytes(2))); - total += size; + total += length; } rs.close(); ps.close(); @@ -1384,7 +1440,7 @@ abstract class JdbcDatabase implements Database<Connection> { total + "/" + capacity + " bytes"); if(total == capacity) return ids; // Do we have any sendable group messages? - sql = "SELECT size, messages.messageId FROM messages" + sql = "SELECT length, messages.messageId FROM messages" + " JOIN contactSubscriptions" + " ON messages.groupId = contactSubscriptions.groupId" + " JOIN visibilities" @@ -1403,10 +1459,10 @@ abstract class JdbcDatabase implements Database<Connection> { ps.setShort(2, (short) Status.NEW.ordinal()); rs = ps.executeQuery(); while(rs.next()) { - int size = rs.getInt(1); - if(total + size > capacity) break; + int length = rs.getInt(1); + if(total + length > capacity) break; ids.add(new MessageId(rs.getBytes(2))); - total += size; + total += length; } rs.close(); ps.close(); diff --git a/components/net/sf/briar/protocol/MessageEncoderImpl.java b/components/net/sf/briar/protocol/MessageEncoderImpl.java index 39ab3b47f33a08ba12b339e5f94f8835868d06fa..eeef6c410541a860e426df36a107536215b38b7c 100644 --- a/components/net/sf/briar/protocol/MessageEncoderImpl.java +++ b/components/net/sf/briar/protocol/MessageEncoderImpl.java @@ -16,6 +16,7 @@ import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Writer; @@ -81,6 +82,9 @@ class MessageEncoderImpl implements MessageEncoder { ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); // Initialise the consumers + CountingConsumer counting = new CountingConsumer( + ProtocolConstants.MAX_PACKET_LENGTH); + w.addConsumer(counting); Consumer digestingConsumer = new DigestingConsumer(messageDigest); w.addConsumer(digestingConsumer); Consumer authorConsumer = null; @@ -110,6 +114,7 @@ class MessageEncoderImpl implements MessageEncoder { random.nextBytes(salt); w.writeBytes(salt); w.writeBytes(body); + int bodyStart = (int) counting.getCount() - body.length; // Sign the message with the author's private key, if there is one if(authorKey == null) { w.writeNull(); @@ -137,6 +142,6 @@ class MessageEncoderImpl implements MessageEncoder { GroupId groupId = group == null ? null : group.getId(); AuthorId authorId = author == null ? null : author.getId(); return new MessageImpl(id, parent, groupId, authorId, subject, - timestamp, raw); + timestamp, raw, bodyStart, body.length); } } diff --git a/components/net/sf/briar/protocol/MessageImpl.java b/components/net/sf/briar/protocol/MessageImpl.java index 33b06eab06d5b93b966715511b36617baac7c8f4..6e35f76102d6434698402b776d1fc5ce1697f044 100644 --- a/components/net/sf/briar/protocol/MessageImpl.java +++ b/components/net/sf/briar/protocol/MessageImpl.java @@ -1,8 +1,5 @@ package net.sf.briar.protocol; -import java.io.ByteArrayInputStream; -import java.io.InputStream; - import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Message; @@ -17,9 +14,15 @@ class MessageImpl implements Message { private final String subject; private final long timestamp; private final byte[] raw; + private final int bodyStart, bodyLength; public MessageImpl(MessageId id, MessageId parent, GroupId group, - AuthorId author, String subject, long timestamp, byte[] raw) { + AuthorId author, String subject, long timestamp, byte[] raw, + int bodyStart, int bodyLength) { + if(bodyStart + bodyLength > raw.length) + throw new IllegalArgumentException(); + if(bodyLength > Message.MAX_BODY_LENGTH) + throw new IllegalArgumentException(); this.id = id; this.parent = parent; this.group = group; @@ -27,6 +30,8 @@ class MessageImpl implements Message { this.subject = subject; this.timestamp = timestamp; this.raw = raw; + this.bodyStart = bodyStart; + this.bodyLength = bodyLength; } public MessageId getId() { @@ -57,12 +62,16 @@ class MessageImpl implements Message { return raw.length; } - public byte[] getSerialisedBytes() { + public byte[] getSerialised() { return raw; } - public InputStream getSerialisedStream() { - return new ByteArrayInputStream(raw); + public int getBodyStart() { + return bodyStart; + } + + public int getBodyLength() { + return bodyLength; } @Override diff --git a/components/net/sf/briar/protocol/MessageReader.java b/components/net/sf/briar/protocol/MessageReader.java index 52e398a1c4240701640f7a4a6349be0fade00def..1039a42d7b9c581f04880903be59435ee93e1526 100644 --- a/components/net/sf/briar/protocol/MessageReader.java +++ b/components/net/sf/briar/protocol/MessageReader.java @@ -84,8 +84,10 @@ class MessageReader implements ObjectReader<Message> { // Read the salt byte[] salt = r.readBytes(Message.SALT_LENGTH); if(salt.length != Message.SALT_LENGTH) throw new FormatException(); - // Skip the message body - r.readBytes(Message.MAX_BODY_LENGTH); + // Read the message body + byte[] body = r.readBytes(Message.MAX_BODY_LENGTH); + // Record the offset of the body within the message + int bodyStart = (int) counting.getCount() - body.length; // Record the length of the data covered by the author's signature int signedByAuthor = (int) counting.getCount(); // Read the author's signature, if there is one @@ -131,6 +133,6 @@ class MessageReader implements ObjectReader<Message> { GroupId groupId = group == null ? null : group.getId(); AuthorId authorId = author == null ? null : author.getId(); return new MessageImpl(id, parent, groupId, authorId, subject, - timestamp, raw); + timestamp, raw, bodyStart, body.length); } } diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java index acbf2e4922fb23f920134fb2fa9ff2db8a6831f1..0795d27a4e7db5cc20fb1f2df777a21ed0a273f2 100644 --- a/components/net/sf/briar/serial/ReaderImpl.java +++ b/components/net/sf/briar/serial/ReaderImpl.java @@ -272,26 +272,21 @@ class ReaderImpl implements Reader { } public String readString() throws IOException { + return readString(maxStringLength); + } + + public String readString(int maxLength) throws IOException { if(!hasString()) throw new FormatException(); consumeLookahead(); int length; if(next == Tag.STRING) length = readLength(); else length = 0xFF & next ^ Tag.SHORT_STRING; - if(length > maxStringLength) throw new FormatException(); + if(length > maxLength) throw new FormatException(); if(length == 0) return ""; readIntoBuffer(length); return new String(buf, 0, length, "UTF-8"); } - public String readString(int maxLength) throws IOException { - setMaxStringLength(maxLength); - try { - return readString(); - } finally { - resetMaxStringLength(); - } - } - private int readLength() throws IOException { if(!hasLength()) throw new FormatException(); if(next >= 0) return readUint7(); @@ -315,27 +310,22 @@ class ReaderImpl implements Reader { } public byte[] readBytes() throws IOException { + return readBytes(maxBytesLength); + } + + public byte[] readBytes(int maxLength) throws IOException { if(!hasBytes()) throw new FormatException(); consumeLookahead(); int length; if(next == Tag.BYTES) length = readLength(); else length = 0xFF & next ^ Tag.SHORT_BYTES; - if(length > maxBytesLength) throw new FormatException(); + if(length > maxLength) throw new FormatException(); if(length == 0) return EMPTY_BUFFER; byte[] b = new byte[length]; readIntoBuffer(b, length); return b; } - public byte[] readBytes(int maxLength) throws IOException { - setMaxBytesLength(maxLength); - try { - return readBytes(); - } finally { - resetMaxBytesLength(); - } - } - public boolean hasList() throws IOException { if(!hasLookahead) readLookahead(true); if(eof) return false; diff --git a/test/net/sf/briar/ProtocolIntegrationTest.java b/test/net/sf/briar/ProtocolIntegrationTest.java index a33aef927bf3b793946142b351d132674406f9d2..3da633eba2e8ed87cc9f43377cdaa76d489ebfe2 100644 --- a/test/net/sf/briar/ProtocolIntegrationTest.java +++ b/test/net/sf/briar/ProtocolIntegrationTest.java @@ -142,10 +142,10 @@ public class ProtocolIntegrationTest extends TestCase { a.finish(); BatchWriter b = protocolWriterFactory.createBatchWriter(out1); - assertTrue(b.writeMessage(message.getSerialisedBytes())); - assertTrue(b.writeMessage(message1.getSerialisedBytes())); - assertTrue(b.writeMessage(message2.getSerialisedBytes())); - assertTrue(b.writeMessage(message3.getSerialisedBytes())); + assertTrue(b.writeMessage(message.getSerialised())); + assertTrue(b.writeMessage(message1.getSerialised())); + assertTrue(b.writeMessage(message2.getSerialised())); + assertTrue(b.writeMessage(message3.getSerialised())); b.finish(); OfferWriter o = protocolWriterFactory.createOfferWriter(out1); @@ -255,6 +255,6 @@ public class ProtocolIntegrationTest extends TestCase { assertEquals(m1.getGroup(), m2.getGroup()); assertEquals(m1.getAuthor(), m2.getAuthor()); assertEquals(m1.getTimestamp(), m2.getTimestamp()); - assertArrayEquals(m1.getSerialisedBytes(), m2.getSerialisedBytes()); + assertArrayEquals(m1.getSerialised(), m2.getSerialised()); } } diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java index 9f26ce75283b28d799daecd11adbf29306f8b840..2e8691ede953e63f1307dd1a41dde83810779490 100644 --- a/test/net/sf/briar/db/H2DatabaseTest.java +++ b/test/net/sf/briar/db/H2DatabaseTest.java @@ -4,6 +4,7 @@ import static org.junit.Assert.assertArrayEquals; import java.io.File; import java.sql.Connection; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -1606,6 +1607,46 @@ public class H2DatabaseTest extends TestCase { db.close(); } + @Test + public void testGetMessageBody() throws Exception { + Database<Connection> db = open(false); + Connection txn = db.startTransaction(); + + // Add a contact and subscribe to a group + assertEquals(contactId, db.addContact(txn, transports, secret)); + db.addSubscription(txn, group); + + // Store a couple of messages + int bodyLength = raw.length - 20; + Message message1 = new TestMessage(messageId, null, groupId, null, + subject, timestamp, raw, 5, bodyLength); + Message privateMessage1 = new TestMessage(privateMessageId, null, null, + null, subject, timestamp, raw, 10, bodyLength); + db.addGroupMessage(txn, message1); + db.addPrivateMessage(txn, privateMessage1, contactId); + + // Calculate the expected message bodies + byte[] expectedBody = new byte[bodyLength]; + System.arraycopy(raw, 5, expectedBody, 0, bodyLength); + assertFalse(Arrays.equals(expectedBody, new byte[bodyLength])); + byte[] expectedBody1 = new byte[bodyLength]; + System.arraycopy(raw, 10, expectedBody1, 0, bodyLength); + System.arraycopy(raw, 10, expectedBody1, 0, bodyLength); + + // Retrieve the raw messages + assertArrayEquals(raw, db.getMessage(txn, messageId)); + assertArrayEquals(raw, db.getMessage(txn, privateMessageId)); + + // Retrieve the message bodies + byte[] body = db.getMessageBody(txn, messageId); + assertArrayEquals(expectedBody, body); + byte[] body1 = db.getMessageBody(txn, privateMessageId); + assertArrayEquals(expectedBody1, body1); + + db.commitTransaction(txn); + db.close(); + } + @Test public void testExceptionHandling() throws Exception { Database<Connection> db = open(false); diff --git a/test/net/sf/briar/db/TestMessage.java b/test/net/sf/briar/db/TestMessage.java index 308469947296730aea221accc0447736870ce48a..42f54e06fdb4d5f86af110926b13676c05bfc236 100644 --- a/test/net/sf/briar/db/TestMessage.java +++ b/test/net/sf/briar/db/TestMessage.java @@ -16,9 +16,16 @@ class TestMessage implements Message { private final String subject; private final long timestamp; private final byte[] raw; + private final int bodyStart, bodyLength; public TestMessage(MessageId id, MessageId parent, GroupId group, AuthorId author, String subject, long timestamp, byte[] raw) { + this(id, parent, group, author, subject, timestamp, raw, 0, raw.length); + } + + public TestMessage(MessageId id, MessageId parent, GroupId group, + AuthorId author, String subject, long timestamp, byte[] raw, + int bodyStart, int bodyLength) { this.id = id; this.parent = parent; this.group = group; @@ -26,6 +33,8 @@ class TestMessage implements Message { this.subject = subject; this.timestamp = timestamp; this.raw = raw; + this.bodyStart = bodyStart; + this.bodyLength = bodyLength; } public MessageId getId() { @@ -56,10 +65,18 @@ class TestMessage implements Message { return raw.length; } - public byte[] getSerialisedBytes() { + public byte[] getSerialised() { return raw; } + public int getBodyStart() { + return bodyStart; + } + + public int getBodyLength() { + return bodyLength; + } + public InputStream getSerialisedStream() { return new ByteArrayInputStream(raw); } diff --git a/test/net/sf/briar/protocol/ProtocolReadWriteTest.java b/test/net/sf/briar/protocol/ProtocolReadWriteTest.java index b749649c679a085152a691cb7a4bf090a4455c23..c502c27b081bea45422513aacc053a6243b430d5 100644 --- a/test/net/sf/briar/protocol/ProtocolReadWriteTest.java +++ b/test/net/sf/briar/protocol/ProtocolReadWriteTest.java @@ -87,7 +87,7 @@ public class ProtocolReadWriteTest extends TestCase { a.finish(); BatchWriter b = writerFactory.createBatchWriter(out); - b.writeMessage(message.getSerialisedBytes()); + b.writeMessage(message.getSerialised()); b.finish(); OfferWriter o = writerFactory.createOfferWriter(out); diff --git a/test/net/sf/briar/protocol/writers/ConstantsTest.java b/test/net/sf/briar/protocol/writers/ConstantsTest.java index 4c966682e75b6696a9556bbf52b56ba1e924213c..372d32dc48d0b551bc328a46350c8b21a1d50798 100644 --- a/test/net/sf/briar/protocol/writers/ConstantsTest.java +++ b/test/net/sf/briar/protocol/writers/ConstantsTest.java @@ -113,7 +113,7 @@ public class ConstantsTest extends TestCase { ProtocolConstants.MAX_PACKET_LENGTH); BatchWriter b = new BatchWriterImpl(out, serial, writerFactory, crypto.getMessageDigest()); - assertTrue(b.writeMessage(message.getSerialisedBytes())); + assertTrue(b.writeMessage(message.getSerialised())); b.finish(); // Check the size of the serialised batch assertTrue(out.size() > UniqueId.LENGTH + Group.MAX_NAME_LENGTH +