001/*
002 * Copyright (c) 2014, 2014, Oracle and/or its affiliates. All rights reserved.
003 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
004 *
005 * This code is free software; you can redistribute it and/or modify it
006 * under the terms of the GNU General Public License version 2 only, as
007 * published by the Free Software Foundation.
008 *
009 * This code is distributed in the hope that it will be useful, but WITHOUT
010 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
011 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
012 * version 2 for more details (a copy is included in the LICENSE file that
013 * accompanied this code).
014 *
015 * You should have received a copy of the GNU General Public License version
016 * 2 along with this work; if not, write to the Free Software Foundation,
017 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
018 *
019 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
020 * or visit www.oracle.com if you need additional information or have any
021 * questions.
022 */
023package com.oracle.graal.jtt.except;
024
025import jdk.internal.org.objectweb.asm.*;
026
027import org.junit.*;
028
029import com.oracle.graal.jtt.*;
030
031public class UntrustedInterfaces extends JTTTest {
032
033    public interface CallBack {
034        int callBack(TestInterface ti);
035    }
036
037    private interface TestInterface {
038        int method();
039    }
040
041    /**
042     * What a GoodPill would look like.
043     *
044     * <pre>
045     * private static final class GoodPill extends Pill {
046     *     public void setField() {
047     *         field = new TestConstant();
048     *     }
049     *
050     *     public void setStaticField() {
051     *         staticField = new TestConstant();
052     *     }
053     *
054     *     public int callMe(CallBack callback) {
055     *         return callback.callBack(new TestConstant());
056     *     }
057     *
058     *     public TestInterface get() {
059     *         return new TestConstant();
060     *     }
061     * }
062     *
063     * private static final class TestConstant implements TestInterface {
064     *     public int method() {
065     *         return 42;
066     *     }
067     * }
068     * </pre>
069     */
070    public abstract static class Pill {
071        public static TestInterface staticField;
072        public TestInterface field;
073
074        public abstract void setField();
075
076        public abstract void setStaticField();
077
078        public abstract int callMe(CallBack callback);
079
080        public abstract TestInterface get();
081    }
082
083    public int callBack(TestInterface list) {
084        return list.method();
085    }
086
087    public int staticFieldInvoke(Pill pill) {
088        pill.setStaticField();
089        return Pill.staticField.method();
090    }
091
092    public int fieldInvoke(Pill pill) {
093        pill.setField();
094        return pill.field.method();
095    }
096
097    public int argumentInvoke(Pill pill) {
098        return pill.callMe(ti -> ti.method());
099    }
100
101    public int returnInvoke(Pill pill) {
102        return pill.get().method();
103    }
104
105    @SuppressWarnings("cast")
106    public boolean staticFieldInstanceof(Pill pill) {
107        pill.setStaticField();
108        return Pill.staticField instanceof TestInterface;
109    }
110
111    @SuppressWarnings("cast")
112    public boolean fieldInstanceof(Pill pill) {
113        pill.setField();
114        return pill.field instanceof TestInterface;
115    }
116
117    @SuppressWarnings("cast")
118    public int argumentInstanceof(Pill pill) {
119        return pill.callMe(ti -> ti instanceof TestInterface ? 42 : 24);
120    }
121
122    @SuppressWarnings("cast")
123    public boolean returnInstanceof(Pill pill) {
124        return pill.get() instanceof TestInterface;
125    }
126
127    public TestInterface staticFieldCheckcast(Pill pill) {
128        pill.setStaticField();
129        return TestInterface.class.cast(Pill.staticField);
130    }
131
132    public TestInterface fieldCheckcast(Pill pill) {
133        pill.setField();
134        return TestInterface.class.cast(pill.field);
135    }
136
137    public int argumentCheckcast(Pill pill) {
138        return pill.callMe(ti -> TestInterface.class.cast(ti).method());
139    }
140
141    public TestInterface returnCheckcast(Pill pill) {
142        return TestInterface.class.cast(pill.get());
143    }
144
145    private static Pill poisonPill;
146
147    // Checkstyle: stop
148    @BeforeClass
149    public static void setUp() throws InstantiationException, IllegalAccessException, ClassNotFoundException {
150        poisonPill = (Pill) new PoisonLoader().findClass(PoisonLoader.POISON_IMPL_NAME).newInstance();
151    }
152
153    // Checkstyle: resume
154
155    @Test
156    public void testStaticField0() {
157        runTest("staticFieldInvoke", poisonPill);
158    }
159
160    @Test
161    public void testStaticField1() {
162        runTest("staticFieldInstanceof", poisonPill);
163    }
164
165    @Test
166    public void testStaticField2() {
167        runTest("staticFieldCheckcast", poisonPill);
168    }
169
170    @Test
171    public void testField0() {
172        runTest("fieldInvoke", poisonPill);
173    }
174
175    @Test
176    public void testField1() {
177        runTest("fieldInstanceof", poisonPill);
178    }
179
180    @Test
181    public void testField2() {
182        runTest("fieldCheckcast", poisonPill);
183    }
184
185    @Test
186    public void testArgument0() {
187        runTest("argumentInvoke", poisonPill);
188    }
189
190    @Test
191    public void testArgument1() {
192        runTest("argumentInstanceof", poisonPill);
193    }
194
195    @Test
196    public void testArgument2() {
197        runTest("argumentCheckcast", poisonPill);
198    }
199
200    @Test
201    public void testReturn0() {
202        runTest("returnInvoke", poisonPill);
203    }
204
205    @Test
206    public void testReturn1() {
207        runTest("returnInstanceof", poisonPill);
208    }
209
210    @Test
211    public void testReturn2() {
212        runTest("returnCheckcast", poisonPill);
213    }
214
215    private static class PoisonLoader extends ClassLoader {
216        public static final String POISON_IMPL_NAME = "com.oracle.graal.jtt.except.PoisonPill";
217
218        @Override
219        protected Class<?> findClass(String name) throws ClassNotFoundException {
220            if (name.equals(POISON_IMPL_NAME)) {
221                ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
222
223                cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, POISON_IMPL_NAME.replace('.', '/'), null, Type.getInternalName(Pill.class), null);
224                // constructor
225                MethodVisitor constructor = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "()V", null, null);
226                constructor.visitCode();
227                constructor.visitVarInsn(Opcodes.ALOAD, 0);
228                constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Pill.class), "<init>", "()V", false);
229                constructor.visitInsn(Opcodes.RETURN);
230                constructor.visitMaxs(0, 0);
231                constructor.visitEnd();
232
233                MethodVisitor setList = cw.visitMethod(Opcodes.ACC_PUBLIC, "setField", "()V", null, null);
234                setList.visitCode();
235                setList.visitVarInsn(Opcodes.ALOAD, 0);
236                setList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
237                setList.visitInsn(Opcodes.DUP);
238                setList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
239                setList.visitFieldInsn(Opcodes.PUTFIELD, Type.getInternalName(Pill.class), "field", Type.getDescriptor(TestInterface.class));
240                setList.visitInsn(Opcodes.RETURN);
241                setList.visitMaxs(0, 0);
242                setList.visitEnd();
243
244                MethodVisitor setStaticList = cw.visitMethod(Opcodes.ACC_PUBLIC, "setStaticField", "()V", null, null);
245                setStaticList.visitCode();
246                setStaticList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
247                setStaticList.visitInsn(Opcodes.DUP);
248                setStaticList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
249                setStaticList.visitFieldInsn(Opcodes.PUTSTATIC, Type.getInternalName(Pill.class), "staticField", Type.getDescriptor(TestInterface.class));
250                setStaticList.visitInsn(Opcodes.RETURN);
251                setStaticList.visitMaxs(0, 0);
252                setStaticList.visitEnd();
253
254                MethodVisitor callMe = cw.visitMethod(Opcodes.ACC_PUBLIC, "callMe", Type.getMethodDescriptor(Type.INT_TYPE, Type.getType(CallBack.class)), null, null);
255                callMe.visitCode();
256                callMe.visitVarInsn(Opcodes.ALOAD, 1);
257                callMe.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
258                callMe.visitInsn(Opcodes.DUP);
259                callMe.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
260                callMe.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(CallBack.class), "callBack", Type.getMethodDescriptor(Type.INT_TYPE, Type.getType(TestInterface.class)), true);
261                callMe.visitInsn(Opcodes.IRETURN);
262                callMe.visitMaxs(0, 0);
263                callMe.visitEnd();
264
265                MethodVisitor getList = cw.visitMethod(Opcodes.ACC_PUBLIC, "get", Type.getMethodDescriptor(Type.getType(TestInterface.class)), null, null);
266                getList.visitCode();
267                getList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
268                getList.visitInsn(Opcodes.DUP);
269                getList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
270                getList.visitInsn(Opcodes.ARETURN);
271                getList.visitMaxs(0, 0);
272                getList.visitEnd();
273
274                cw.visitEnd();
275
276                byte[] bytes = cw.toByteArray();
277                return defineClass(name, bytes, 0, bytes.length);
278            }
279            return super.findClass(name);
280        }
281    }
282}