From 45b4bef348808b50595394de25a0265bd722715c Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Wed, 20 Jul 2011 15:07:17 +0100
Subject: [PATCH] Catch ClassCastException when the encountered type doesn't
 match the expected type, and re-throw as FormatException.

---
 api/net/sf/briar/api/serial/Reader.java       |  2 +-
 .../sf/briar/protocol/BundleReaderImpl.java   |  4 +-
 .../net/sf/briar/serial/ReaderImpl.java       |  9 ++--
 test/net/sf/briar/serial/ReaderImplTest.java  | 42 ++++++++++++++++++-
 4 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/api/net/sf/briar/api/serial/Reader.java b/api/net/sf/briar/api/serial/Reader.java
index 3bf5b6a278..b1b7997d95 100644
--- a/api/net/sf/briar/api/serial/Reader.java
+++ b/api/net/sf/briar/api/serial/Reader.java
@@ -66,6 +66,6 @@ public interface Reader {
 	boolean hasUserDefinedTag() throws IOException;
 	int readUserDefinedTag() throws IOException;
 	void readUserDefinedTag(int tag) throws IOException;
-	<T> T readUserDefinedObject(int tag) throws IOException,
+	<T> T readUserDefinedObject(int tag, Class<T> t) throws IOException,
 	GeneralSecurityException;
 }
diff --git a/components/net/sf/briar/protocol/BundleReaderImpl.java b/components/net/sf/briar/protocol/BundleReaderImpl.java
index 42c47cc601..db12a3cd9b 100644
--- a/components/net/sf/briar/protocol/BundleReaderImpl.java
+++ b/components/net/sf/briar/protocol/BundleReaderImpl.java
@@ -31,7 +31,7 @@ class BundleReaderImpl implements BundleReader {
 		if(state != State.START) throw new IllegalStateException();
 		reader.addObjectReader(Tags.HEADER, headerReader);
 		reader.readUserDefinedTag(Tags.HEADER);
-		Header h = reader.readUserDefinedObject(Tags.HEADER);
+		Header h = reader.readUserDefinedObject(Tags.HEADER, Header.class);
 		reader.removeObjectReader(Tags.HEADER);
 		state = State.FIRST_BATCH;
 		return h;
@@ -53,7 +53,7 @@ class BundleReaderImpl implements BundleReader {
 			return null;
 		}
 		reader.readUserDefinedTag(Tags.BATCH);
-		return reader.readUserDefinedObject(Tags.BATCH);
+		return reader.readUserDefinedObject(Tags.BATCH, Batch.class);
 	}
 
 	public void finish() throws IOException {
diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java
index dfc2b45243..be019665f9 100644
--- a/components/net/sf/briar/serial/ReaderImpl.java
+++ b/components/net/sf/briar/serial/ReaderImpl.java
@@ -386,11 +386,10 @@ class ReaderImpl implements Reader {
 		throw new FormatException();
 	}
 
-	@SuppressWarnings("unchecked")
 	private <T> T readObject(Class<T> t) throws IOException,
 	GeneralSecurityException {
 		try {
-			return (T) readObject();
+			return t.cast(readObject());
 		} catch(ClassCastException e) {
 			throw new FormatException();
 		}
@@ -507,14 +506,12 @@ class ReaderImpl implements Reader {
 		if(readUserDefinedTag() != tag) throw new FormatException();
 	}
 
-	public <T> T readUserDefinedObject(int tag) throws IOException,
+	public <T> T readUserDefinedObject(int tag, Class<T> t) throws IOException,
 	GeneralSecurityException {
 		ObjectReader<?> o = objectReaders.get(tag);
 		if(o == null) throw new FormatException();
 		try {
-			@SuppressWarnings("unchecked")
-			ObjectReader<T> cast = (ObjectReader<T>) o;
-			return cast.readObject(this);
+			return t.cast(o.readObject(this));
 		} catch(ClassCastException e) {
 			throw new FormatException();
 		}
diff --git a/test/net/sf/briar/serial/ReaderImplTest.java b/test/net/sf/briar/serial/ReaderImplTest.java
index 83e8366963..33f2472086 100644
--- a/test/net/sf/briar/serial/ReaderImplTest.java
+++ b/test/net/sf/briar/serial/ReaderImplTest.java
@@ -8,6 +8,7 @@ import java.util.Map;
 import java.util.Map.Entry;
 
 import junit.framework.TestCase;
+import net.sf.briar.api.serial.FormatException;
 import net.sf.briar.api.serial.ObjectReader;
 import net.sf.briar.api.serial.Raw;
 import net.sf.briar.api.serial.RawByteArray;
@@ -176,6 +177,17 @@ public class ReaderImplTest extends TestCase {
 		assertTrue(r.eof());
 	}
 
+	@Test
+	public void testReadListTypeSafeThrowsFormatException() throws Exception {
+		setContents("A" + "3" + "01" + "83666F6F" + "03");
+		// Trying to read a mixed list as a list of bytes should throw a
+		// FormatException
+		try {
+			r.readList(Byte.class);
+			assertTrue(false);
+		} catch(FormatException expected) {}
+	}
+
 	@Test
 	public void testReadShortMap() throws Exception {
 		setContents("B" + "2" + "83666F6F" + "7B" + "90" + "F0");
@@ -338,7 +350,35 @@ public class ReaderImplTest extends TestCase {
 			}
 		});
 		assertEquals(0, r.readUserDefinedTag());
-		assertEquals("foo", r.<Foo>readUserDefinedObject(0).s);
+		assertEquals("foo", r.readUserDefinedObject(0, Foo.class).s);
+	}
+
+	@Test
+	public void testUnknownTagThrowsFormatException() throws Exception {
+		setContents("C0" + "83666F6F");
+		// No object reader has been added for tag 0
+		assertEquals(0, r.readUserDefinedTag());
+		try {
+			r.readUserDefinedObject(0, Foo.class);
+			assertTrue(false);
+		} catch(FormatException expected) {}
+	}
+
+	@Test
+	public void testWrongClassThrowsFormatException() throws Exception {
+		setContents("C0" + "83666F6F");
+		// Add an object reader for tag 0, class Foo
+		r.addObjectReader(0, new ObjectReader<Foo>() {
+			public Foo readObject(Reader r) throws IOException {
+				return new Foo(r.readString());
+			}
+		});
+		assertEquals(0, r.readUserDefinedTag());
+		// Trying to read the object as class Bar should throw a FormatException
+		try {
+			r.readUserDefinedObject(0, Bar.class);
+			assertTrue(false);
+		} catch(FormatException expected) {}
 	}
 
 	@Test
-- 
GitLab