Source code for visitors.function_subs

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Visitor to replace the invocations to functions that have been removed with the controlled function
generation functionality (i.e., ``-f`` or ``--functions`` options)."""

from functools import singledispatch
from core import ast, generators, probs
from params.parameters import get_app_args
from typing import List, Union


[docs]@singledispatch def visit(node, targets: List[str], function: ast.Function, is_statement: bool): raise TypeError("Unknown node type: " + node.__class__.__name__)
_program = None @visit.register(ast.Program) def _(program: ast.Program, targets: List[str], function: ast.Function = None, is_statement: bool = True) -> ast.Program: # Pointers, Struct and Arrays only generate literals in declarations. The normal solution is to # generate a lvalue instead, this solution can cause the creation of new functions. The solution to this new problem # is to eliminate the probability of invocation before generate a new literal. global _program _program = program new_probs = {False: 1, True: 0} old_probs = dict(probs.call_prob) probs.call_prob = new_probs for function in program.functions: visit(function, targets, function, True) visit(program.main, targets, program.main, True) # Restore the old probability probs.call_prob = dict(old_probs) return program @visit.register(ast.Function) def _(node: ast.Function, targets: List[str], function: ast.Function, is_statement: bool) -> ast.Function: node.children = _visit_ast_nodes(node.stmts, targets, node, True) return node @visit.register(ast.Invocation) def _(invocation: ast.Invocation, targets: List[str], function: ast.Function, is_statement: bool) \ -> Union[ast.ASTNode, None]: if invocation.func_name in targets and not isinstance(invocation.return_type, ast.Void): # If invocation is a statement if is_statement: if get_app_args().verbose: print("{}: Remove call {}".format(function.name, invocation.func_name)) return None # If invocation is an expression else: literal = generators.generate_literal( _program, function, invocation.return_type, from_declaration=False) if get_app_args().verbose: print("{} Subs call {} -> ({}) {}".format( function.name, invocation.func_name, invocation.return_type.name, literal.to_str())) return literal else: invocation.arguments = _visit_ast_nodes(invocation.arguments, targets, function, False) return invocation @visit.register(ast.Assignment) @visit.register(ast.BinaryExpression) @visit.register(ast.ArrayAccessExpression) @visit.register(ast.UnaryExpression) @visit.register(ast.Return) @visit.register(ast.StructAccessExpression) @visit.register(ast.TernaryExpression) def _(node, targets: List[str], function: ast.Function, is_statement: bool) -> ast.ASTNode: node.children = _visit_ast_nodes(node.children, targets, function, False) return node @visit.register(ast.Block) def _(node: ast.Block, targets: List[str], function: ast.Function, is_statement: bool) -> ast.Block: node.statements = _visit_ast_nodes(node.statements, targets, function, True) return node @visit.register(ast.Do) @visit.register(ast.While) def _(node, targets: List[str], function: ast.Function, is_statement: bool) -> ast.ASTNode: node.condition = visit(node.condition, targets, function, False) node.statements = _visit_ast_nodes(node.statements, targets, function, True) return node @visit.register(ast.If) def _(node: ast.If, targets: List[str], function: ast.Function, is_statement: bool) -> ast.If: node.condition = visit(node.condition, targets, function, False) node.if_statements = _visit_ast_nodes(node.if_statements, targets, function, True) node.else_statements = _visit_ast_nodes(node.else_statements, targets, function, True) return node @visit.register(ast.For) def _(node: ast.For, targets: List[str], function: ast.Function, is_statement: bool) -> ast.For: node.initialization = visit(node.initialization, targets, function, False) node.condition = visit(node.condition, targets, function, False) node.increment = visit(node.increment, targets, function, False) node.statements = _visit_ast_nodes(node.statements, targets, function, True) return node @visit.register(ast.Switch) def _(node: ast.Switch, targets: List[str], function: ast.Function, is_statement: bool) -> ast.Switch: node.condition = visit(node.condition, targets, function, False) for case_literal, case_statements in node.cases.items(): node.cases[case_literal] = _visit_ast_nodes(case_statements, targets, function, True) node.default = _visit_ast_nodes(node.default, targets, function, True) return node @visit.register(ast.Literal) @visit.register(ast.Variable) @visit.register(ast.Break) @visit.register(ast.Label) def _(node, targets: List[str], function: ast.Function, is_statement: bool) -> ast.ASTNode: return node def _visit_ast_nodes(nodes: list, targets: List[str], function: ast.Function, is_statement: bool) -> list: """Traverses the nodes and returns the new ones""" # the statement could be a type, and then it is not traversed (visited) result_nodes = list() for node in nodes: if isinstance(node, ast.ASTNode): # it is an ASTNode (not a type) new_node = visit(node, targets, function, is_statement) if new_node is not None: # this visitor removes invocations, so there might be stmts removed result_nodes.append(new_node) else: # not an ASTNode result_nodes.append(node) return result_nodes