# Source code for visitors.return_instrumentation

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Example visitor that adds a return label before every return statement.
The purpose is to identify the binary code used to return from a function.
Return statements nested inside blocks, loops, ``if`` and ``switch`` statements are labelled too.
If the function returns ``void``, a label and an explicit return statement are added at the end of the function."""


from functools import singledispatch
from core.utils import print_if_verbose
from typing import List
from core import ast

_return_label_counter = 0


def generate_label() -> str:
    """Produce a label string that has not been handed out before.

    The module-level ``_return_label_counter`` is bumped before the label is
    built, so successive calls yield ``__RETURN1__``, ``__RETURN2__``, ...
    (starting from the counter's current value plus one).

    Returns:
        A label of the form ``__RETURN<n>__``.
    """
    global _return_label_counter
    _return_label_counter = _return_label_counter + 1
    label_number = _return_label_counter
    return "__RETURN{}__".format(label_number)
@singledispatch
def visit(node):
    """Fallback dispatch target for nodes with no registered visitor.

    Every concrete AST node type this pass understands is registered
    separately via ``visit.register``; reaching this function means the
    node's class was not registered.

    Raises:
        TypeError: always, with the unsupported node's class name.
    """
    unknown_type_name = node.__class__.__name__
    raise TypeError(f"Unknown node type: {unknown_type_name}")
@visit.register(ast.Program)
def _(program: ast.Program):
    """Instrument every function of the program, then ``program.main``."""
    print_if_verbose("")
    print_if_verbose("*" * 80)
    print_if_verbose("* RETURN INSTRUMENTATION")
    print_if_verbose("*" * 80)
    for function in program.functions:
        visit(function)
    visit(program.main)


def _instrument_statements(statements: List[ast.ASTNode]) -> List[ast.ASTNode]:
    """Return a new statement list with a fresh label before each return statement.

    Every statement is also passed to ``visit`` so that return statements
    nested inside control-flow statements (blocks, loops, ifs, switches) are
    instrumented recursively. The input list object is not modified; callers
    store the returned list back on the node.

    Args:
        statements: the statements of one body (function, loop, branch, case).

    Returns:
        The statements, with an ``ast.Label`` inserted immediately before
        every top-level ``ast.Return`` in this list.
    """
    instrumented_stmts = []
    for stmt in statements:
        if isinstance(stmt, ast.Return):
            print_if_verbose(f"Return stmt: {stmt.to_str()}")
            label_ast = ast.Label(generate_label())
            instrumented_stmts.append(label_ast)
        # Recurse so returns inside nested bodies are labelled as well.
        visit(stmt)
        instrumented_stmts.append(stmt)
    return instrumented_stmts


@visit.register(ast.Function)
def _(function: ast.Function):
    """Add a RETURN label before each return statement of a function body."""
    function.stmts = _instrument_statements(function.stmts)
    # Procedures (void return type) may fall off the end without an explicit
    # return, so append a labelled return to give them an identifiable exit.
    if isinstance(function.return_type, ast.Void):
        function.stmts.append(ast.Label(generate_label()))
        function.stmts.append(ast.Return())


@visit.register(ast.Block)
@visit.register(ast.Do)
@visit.register(ast.While)
@visit.register(ast.For)
def _(node):
    """Add a RETURN label before each return statement in the node's body."""
    node.statements = _instrument_statements(node.statements)


@visit.register(ast.If)
def _(node: ast.If):
    """Add RETURN labels before return statements in both if and else bodies."""
    node.if_statements = _instrument_statements(node.if_statements)
    node.else_statements = _instrument_statements(node.else_statements)


@visit.register(ast.Switch)
def _(node: ast.Switch):
    """Add RETURN labels before return statements in every case and the default."""
    # Only existing keys are reassigned, so updating the dict while iterating
    # its items is safe; use the unpacked value instead of a second lookup.
    for case_literal, case_statements in node.cases.items():
        node.cases[case_literal] = _instrument_statements(case_statements)
    node.default = _instrument_statements(node.default)


@visit.register(ast.Assignment)
@visit.register(ast.UnaryExpression)
@visit.register(ast.Return)
@visit.register(ast.Invocation)
@visit.register(ast.Variable)
@visit.register(ast.Literal)
@visit.register(ast.TernaryASTNode)
@visit.register(ast.BinaryASTNode)
@visit.register(ast.UnaryASTNode)
@visit.register(ast.Break)
@visit.register(ast.Label)
def _(statement):
    """Nothing to instrument for these node types.

    The label for an ``ast.Return`` is inserted by ``_instrument_statements``
    in the enclosing statement list, not by visiting the return itself.
    """
    pass