diff --git a/src/rtruffle/abstract_node.py b/src/rtruffle/abstract_node.py index 0ef3e080..0827fea3 100644 --- a/src/rtruffle/abstract_node.py +++ b/src/rtruffle/abstract_node.py @@ -62,7 +62,8 @@ def _adapt_after_inlining(node, mgenc): n.adapt_after_inlining(mgenc) else: current = getattr(node, child_slot) - current.adapt_after_inlining(mgenc) + if current is not None: + current.adapt_after_inlining(mgenc) node.handle_inlining(mgenc) cls.adapt_after_inlining = _adapt_after_inlining @@ -83,7 +84,10 @@ def _adapt_after_outer_inlined(node, removed_ctx_level, mgenc_with_inlined): ) else: current = getattr(node, child_slot) - current.adapt_after_outer_inlined(removed_ctx_level, mgenc_with_inlined) + if current is not None: + current.adapt_after_outer_inlined( + removed_ctx_level, mgenc_with_inlined + ) node.handle_outer_inlined(removed_ctx_level, mgenc_with_inlined) cls.adapt_after_outer_inlined = _adapt_after_outer_inlined diff --git a/src/som/compiler/ast/parser.py b/src/som/compiler/ast/parser.py index 5950bd5f..68cebeac 100644 --- a/src/som/compiler/ast/parser.py +++ b/src/som/compiler/ast/parser.py @@ -25,6 +25,7 @@ IfInlinedNode, IfElseInlinedNode, ) +from som.interpreter.ast.nodes.specialized.literal_to_do import ToDoInlined from som.interpreter.ast.nodes.specialized.literal_while import WhileInlinedNode from som.vm.symbols import symbol_for @@ -72,7 +73,7 @@ def _create_sequence_node(self, coordinate, expressions): if not expressions: from som.vm.globals import nilObject - return LiteralNode(nilObject) + return LiteralNode(nilObject, self._get_source_section(coordinate)) if len(expressions) == 1: return expressions[0] @@ -291,6 +292,18 @@ def _try_inlining_or(receiver, arg_expr, source, mgenc): arg_body = arg_expr.get_method().inline(mgenc) return OrInlinedNode(receiver, arg_body, source) + @staticmethod + def _try_inlining_to_do(receiver, arguments, source, mgenc): + if not isinstance(arguments[1], BlockNode): + return None + + block_method = arguments[1].get_method() + do_expr = block_method.inline(mgenc) + idx_arg = block_method.get_argument(1, 0) + return ToDoInlined( + receiver, arguments[0], do_expr, mgenc.get_inlined_local(idx_arg, 0), source + ) + def _keyword_message(self, mgenc, receiver): is_super_send = self._super_send @@ -358,6 +371,12 @@ def _keyword_message(self, mgenc, receiver): ) if inlined is not None: return inlined + elif keyword == "to:do:": + inlined = self._try_inlining_to_do( + receiver, arguments, source, mgenc + ) + if inlined is not None: + return inlined selector = symbol_for(keyword) @@ -400,8 +419,7 @@ def _literal(self): coord = self._lexer.get_source_coordinate() val = self._get_object_for_current_literal() - lit = LiteralNode(val) - self._assign_source(lit, coord) + lit = LiteralNode(val, self._get_source_section(coord)) return lit def _get_object_for_current_literal(self): diff --git a/src/som/compiler/ast/variable.py b/src/som/compiler/ast/variable.py index e54dd5a3..28040b35 100644 --- a/src/som/compiler/ast/variable.py +++ b/src/som/compiler/ast/variable.py @@ -147,7 +147,8 @@ def get_initialized_read_node(self, context_level, source_section): def copy_for_inlining(self, idx): if self._name == "$blockSelf": return None - return Argument(self._name, idx, self.source) + assert self._name != "self" + return Local(self._name, idx, self.source) def __str__(self): return "Argument(" + self._name + " idx: " + str(self.idx) + ")" diff --git a/src/som/compiler/bc/bytecode_generator.py b/src/som/compiler/bc/bytecode_generator.py index d8a83817..642a550b 100644 --- a/src/som/compiler/bc/bytecode_generator.py +++ b/src/som/compiler/bc/bytecode_generator.py @@ -24,6 +24,10 @@ def emit_push_argument(mgenc, idx, ctx): emit3(mgenc, BC.push_argument, idx, ctx, 1) +def emit_nil_local(mgenc, idx): + emit2(mgenc, BC.nil_local, idx, 0) + + def emit_return_self(mgenc): mgenc.optimize_dup_pop_pop_sequence() emit1(mgenc, BC.return_self, 0) @@ -57,6 +61,10 @@ def emit_dup(mgenc): emit1(mgenc, BC.dup, 1) +def emit_dup_second(mgenc): + emit1(mgenc, BC.dup_second, 1) + + def emit_push_block(mgenc, block_method, with_ctx): idx = mgenc.add_literal_if_absent(block_method) emit2(mgenc, BC.push_block if with_ctx else BC.push_block_no_ctx, idx, 1) @@ -211,6 +219,13 @@ def emit_jump_with_dummy_offset(mgenc): return idx +def emit_jump_if_greater_with_dummy_offset(mgenc): + emit1(mgenc, BC.jump_if_greater, 0) + idx = mgenc.add_bytecode_argument_and_get_index(0) + mgenc.add_bytecode_argument(0) + return idx + + def emit_jump_backward_with_offset(mgenc, offset): emit3( mgenc, diff --git a/src/som/compiler/bc/disassembler.py b/src/som/compiler/bc/disassembler.py index 67df56f3..42360b8d 100644 --- a/src/som/compiler/bc/disassembler.py +++ b/src/som/compiler/bc/disassembler.py @@ -78,7 +78,13 @@ def dump_bytecode(m, b, indent=""): + ", context " + str(m.get_bytecode(b + 2)) ) - elif bytecode == Bytecodes.push_frame or bytecode == Bytecodes.pop_frame: + elif ( + bytecode == Bytecodes.push_frame + or bytecode == Bytecodes.pop_frame + or bytecode == Bytecodes.nil_inner + or bytecode == Bytecodes.nil_frame + or bytecode == Bytecodes.nil_local + ): error_println("idx: " + str(m.get_bytecode(b + 1))) elif bytecode == Bytecodes.push_inner or bytecode == Bytecodes.pop_inner: error_println( diff --git a/src/som/compiler/bc/method_generation_context.py b/src/som/compiler/bc/method_generation_context.py index 4114d76c..47f06741 100644 --- a/src/som/compiler/bc/method_generation_context.py +++ b/src/som/compiler/bc/method_generation_context.py @@ -5,6 +5,12 @@ emit_pop, emit_push_constant, emit_jump_backward_with_offset, + emit_dup, + emit_inc, + emit_dup_second, + emit_jump_if_greater_with_dummy_offset, + emit_pop_local, + emit_nil_local, emit_inc_field_push, emit_return_field, ) @@ -96,8 +102,8 @@ def add_local(self, local_name, source, parser): self._local_list.append(local) return local - def inline_locals(self, local_vars): - fresh_copies = MethodGenerationContextBase.inline_locals(self, local_vars) + def inline_as_locals(self, variables): + fresh_copies = MethodGenerationContextBase.inline_as_locals(self, variables) if fresh_copies: self._local_list.extend(fresh_copies) return fresh_copies @@ -476,16 +482,18 @@ def _assemble_literal_return(self, return_candidate, push_candidate): ): return None + source_section = self.lexical_scope.arguments[0].source + if len(self._literals) == 1: - return LiteralReturn(self.signature, self._literals[0]) + return LiteralReturn(self.signature, self._literals[0], source_section) if self._bytecode[0] == Bytecodes.push_0: - return LiteralReturn(self.signature, int_0) + return LiteralReturn(self.signature, int_0, source_section) if self._bytecode[0] == Bytecodes.push_1: - return LiteralReturn(self.signature, int_1) + return LiteralReturn(self.signature, int_1, source_section) if self._bytecode[0] == Bytecodes.push_nil: from som.vm.globals import nilObject - return LiteralReturn(self.signature, nilObject) + return LiteralReturn(self.signature, nilObject, source_section) raise NotImplementedError( "Not sure what's going on. Perhaps some new bytecode or unexpected literal?" ) @@ -502,14 +510,16 @@ def _assemble_global_return(self, return_candidate, push_candidate): global_name = self._literals[0] assert isinstance(global_name, Symbol) + source_section = self.lexical_scope.arguments[0].source + if global_name is sym_true: - return LiteralReturn(self.signature, trueObject) + return LiteralReturn(self.signature, trueObject, source_section) if global_name is sym_false: - return LiteralReturn(self.signature, falseObject) + return LiteralReturn(self.signature, falseObject, source_section) if global_name is sym_nil: from som.vm.globals import nilObject - return LiteralReturn(self.signature, nilObject) + return LiteralReturn(self.signature, nilObject, source_section) return GlobalRead( self.signature, @@ -752,6 +762,52 @@ def inline_andor(self, parser, is_or): return True + def inline_to_do(self, parser): + # HACK: We do assume that the receiver on the stack is a integer, + # HACK: similar to the other inlined messages. + # HACK: We don't support anything but integer at the moment. + push_block_candidate = self._last_bytecode_is_one_of(0, PUSH_BLOCK_BYTECODES) + if push_block_candidate == Bytecodes.invalid: + return False + + assert bytecode_length(push_block_candidate) == 2 + block_literal_idx = self._bytecode[-1] + + to_be_inlined = self._literals[block_literal_idx] + to_be_inlined.merge_scope_into(self) + + block_arg = to_be_inlined.get_argument(1, 0) + i_var_idx = self.get_inlined_local_idx(block_arg, 0) + + self._remove_last_bytecodes(1) # remove push_block* + + self._is_currently_inlining_a_block = True + emit_dup_second(self) + + emit_nil_local(self, i_var_idx) + + loop_begin_idx = self.offset_of_next_instruction() + jump_offset_idx_to_end = emit_jump_if_greater_with_dummy_offset(self) + + emit_dup(self) + + emit_pop_local(self, i_var_idx, 0) + + to_be_inlined.inline(self, False) + + emit_pop(self) + emit_inc(self) + + emit_nil_local(self, i_var_idx) + self.emit_backwards_jump_offset_to_target(loop_begin_idx, parser) + + self.patch_jump_offset_to_point_to_next_instruction( + jump_offset_idx_to_end, parser + ) + + self._is_currently_inlining_a_block = False + return True + def _complete_jumps_and_emit_returning_nil( self, parser, loop_begin_idx, jump_offset_idx_to_skip_loop_body ): diff --git a/src/som/compiler/bc/parser.py b/src/som/compiler/bc/parser.py index bfb3a6aa..bac9d2cd 100644 --- a/src/som/compiler/bc/parser.py +++ b/src/som/compiler/bc/parser.py @@ -290,6 +290,7 @@ def _keyword_message(self, mgenc): keyword == "ifFalse:ifTrue:" and mgenc.inline_if_true_false(self, False) ) + or (keyword == "to:do:" and mgenc.inline_to_do(self)) ): return diff --git a/src/som/compiler/method_generation_context.py b/src/som/compiler/method_generation_context.py index e8459e6e..e1f5033a 100644 --- a/src/som/compiler/method_generation_context.py +++ b/src/som/compiler/method_generation_context.py @@ -83,13 +83,13 @@ def add_local(self, local_name, source, parser): self._locals[local_name] = result return result - def inline_locals(self, local_vars): + def inline_as_locals(self, variables): fresh_copies = [] - for local in local_vars: - fresh_copy = local.copy_for_inlining(len(self._locals)) + for var in variables: + fresh_copy = var.copy_for_inlining(len(self._locals)) if fresh_copy: # fresh_copy can be None, because we don't need the $blockSelf - name = local.get_qualified_name() + name = var.get_qualified_name() assert name not in self._locals self._locals[name] = fresh_copy fresh_copies.append(fresh_copy) @@ -212,10 +212,13 @@ def set_block_signature(self, line, column): self.signature = symbol_for(block_sig) def merge_into_scope(self, scope_to_be_inlined): - assert len(scope_to_be_inlined.arguments) == 1 + arg_vars = scope_to_be_inlined.arguments + if len(arg_vars) > 1: + self.inline_as_locals(arg_vars) + local_vars = scope_to_be_inlined.locals if local_vars: - self.inline_locals(local_vars) + self.inline_as_locals(local_vars) def _strip_colons_and_source_location(method_name): diff --git a/src/som/interpreter/ast/nodes/literal_node.py b/src/som/interpreter/ast/nodes/literal_node.py index cbd4b2eb..74b9d8ed 100644 --- a/src/som/interpreter/ast/nodes/literal_node.py +++ b/src/som/interpreter/ast/nodes/literal_node.py @@ -4,7 +4,7 @@ class LiteralNode(ExpressionNode): _immutable_fields_ = ["_value"] - def __init__(self, value, source_section=None): + def __init__(self, value, source_section): ExpressionNode.__init__(self, source_section) self._value = value @@ -14,4 +14,4 @@ def execute(self, _frame): def create_trivial_method(self, signature): from som.vmobjects.method_trivial import LiteralReturn - return LiteralReturn(signature, self._value) + return LiteralReturn(signature, self._value, self.source_section) diff --git a/src/som/interpreter/ast/nodes/specialized/literal_to_do.py b/src/som/interpreter/ast/nodes/specialized/literal_to_do.py new file mode 100644 index 00000000..d3e97f57 --- /dev/null +++ b/src/som/interpreter/ast/nodes/specialized/literal_to_do.py @@ -0,0 +1,100 @@ +from rlib.jit import JitDriver, we_are_jitted + +from som.interpreter.ast.nodes.expression_node import ExpressionNode +from som.vm.globals import nilObject +from som.vmobjects.double import Double +from som.vmobjects.integer import Integer + + +def get_printable_location(self): + assert isinstance(self, ToDoInlined) + source = self.source_section + return "#to:do: %s:%d:%d" % ( + source.file, + source.coord.start_line, + source.coord.start_column, + ) + + +driver_int = JitDriver( + greens=["self"], + reds="auto", + is_recursive=True, + get_printable_location=get_printable_location, +) + +driver_double = JitDriver( + greens=["self"], + reds="auto", + is_recursive=True, + get_printable_location=get_printable_location, +) + + +class ToDoInlined(ExpressionNode): + _immutable_fields_ = [ + "_from_expr?", + "_to_expr?", + "_do_expr?", + "_idx_var", + "_idx_write?", + ] + _child_nodes_ = ["_from_expr", "_to_expr", "_do_expr", "_idx_write"] + + def __init__(self, from_expr, to_expr, do_expr, idx_var, source_section): + ExpressionNode.__init__(self, source_section) + self._from_expr = self.adopt_child(from_expr) + self._to_expr = self.adopt_child(to_expr) + self._do_expr = self.adopt_child(do_expr) + self._idx_var = idx_var + self._idx_write = self.adopt_child(idx_var.get_write_node(0, None)) + + def execute(self, frame): + start = self._from_expr.execute(frame) + end = self._to_expr.execute(frame) + + if isinstance(start, Double): + return self._execute_double(frame, start, end) + + assert isinstance(start, Integer) + + if isinstance(end, Integer): + end_int = end.get_embedded_integer() + else: + assert isinstance(end, Double) + end_int = int(end.get_embedded_double()) + + if we_are_jitted(): + self._idx_write.write_value(frame, nilObject) + + i = start.get_embedded_integer() + while i <= end_int: + driver_int.jit_merge_point(self=self) + self._idx_write.write_value(frame, Integer(i)) + self._do_expr.execute(frame) + if we_are_jitted(): + self._idx_write.write_value(frame, nilObject) + i += 1 + + return start + + def _execute_double(self, frame, start, end): + if isinstance(end, Integer): + end_d = float(end.get_embedded_integer()) + else: + assert isinstance(end, Double) + end_d = end.get_embedded_double() + + if we_are_jitted(): + self._idx_write.write_value(frame, nilObject) + + i = start.get_embedded_double() + while i <= end_d: + driver_double.jit_merge_point(self=self) + self._idx_write.write_value(frame, Double(i)) + self._do_expr.execute(frame) + if we_are_jitted(): + self._idx_write.write_value(frame, nilObject) + i += 1.0 + + return start diff --git a/src/som/interpreter/ast/nodes/variable_node.py b/src/som/interpreter/ast/nodes/variable_node.py index 28efe426..5a9bb42a 100644 --- a/src/som/interpreter/ast/nodes/variable_node.py +++ b/src/som/interpreter/ast/nodes/variable_node.py @@ -29,12 +29,6 @@ def _specialize(self): def handle_inlining(self, mgenc): if self._context_level == 0: - from som.compiler.ast.variable import Local - - # we got inlined - assert isinstance( - self.var, Local - ), "We are not currently inlining any blocks with arguments" self.var = mgenc.get_inlined_local(self.var, 0) else: self._context_level -= 1 @@ -73,6 +67,9 @@ def __init__(self, var, context_level, value_expr, source_section): def execute(self, frame): return self._specialize().execute(frame) + def write_value(self, frame, value): + self._specialize().write_value(frame, value) + def _specialize(self): return self.replace( self._var.get_initialized_write_node( @@ -161,6 +158,9 @@ def execute(self, frame): self.determine_block(frame).set_outer(self._frame_idx, value) return value + def write_value(self, frame, value): + self.determine_block(frame).set_outer(self._frame_idx, value) + class _LocalVariableNode(ExpressionNode): _immutable_fields_ = ["_frame_idx"] @@ -191,6 +191,9 @@ def execute(self, frame): write_inner(frame, self._frame_idx, val) return val + def write_value(self, frame, value): + write_inner(frame, self._frame_idx, value) + class LocalFrameVarReadNode(_LocalVariableNode): def execute(self, frame): @@ -205,3 +208,6 @@ def execute(self, frame): val = self._expr.execute(frame) write_frame(frame, self._frame_idx, val) return val + + def write_value(self, frame, value): + write_frame(frame, self._frame_idx, value) diff --git a/src/som/interpreter/bc/bytecodes.py b/src/som/interpreter/bc/bytecodes.py index 2fa56bf5..0659967b 100644 --- a/src/som/interpreter/bc/bytecodes.py +++ b/src/som/interpreter/bc/bytecodes.py @@ -5,8 +5,9 @@ class Bytecodes(object): # Bytecodes used by the Simple Object Machine (SOM) halt = 0 dup = halt + 1 + dup_second = dup + 1 - push_frame = dup + 1 + push_frame = dup_second + 1 push_frame_0 = push_frame + 1 push_frame_1 = push_frame_0 + 1 push_frame_2 = push_frame_1 + 1 @@ -50,7 +51,10 @@ class Bytecodes(object): pop_field_0 = pop_field + 1 pop_field_1 = pop_field_0 + 1 - send_1 = pop_field_1 + 1 + nil_frame = pop_field_1 + 1 + nil_inner = nil_frame + 1 + + send_1 = nil_inner + 1 send_2 = send_1 + 1 send_3 = send_2 + 1 send_n = send_3 + 1 @@ -76,13 +80,15 @@ class Bytecodes(object): jump_on_false_top_nil = jump_on_true_top_nil + 1 jump_on_true_pop = jump_on_false_top_nil + 1 jump_on_false_pop = jump_on_true_pop + 1 - jump_backward = jump_on_false_pop + 1 + jump_if_greater = jump_on_false_pop + 1 + jump_backward = jump_if_greater + 1 jump2 = jump_backward + 1 jump2_on_true_top_nil = jump2 + 1 jump2_on_false_top_nil = jump2_on_true_top_nil + 1 jump2_on_true_pop = jump2_on_false_top_nil + 1 jump2_on_false_pop = jump2_on_true_pop + 1 - jump2_backward = jump2_on_false_pop + 1 + jump2_if_greater = jump2_on_false_pop + 1 + jump2_backward = jump2_if_greater + 1 q_super_send_1 = jump2_backward + 1 q_super_send_2 = q_super_send_1 + 1 @@ -93,8 +99,9 @@ class Bytecodes(object): push_argument = push_local + 1 pop_local = push_argument + 1 pop_argument = pop_local + 1 + nil_local = pop_argument + 1 - invalid = pop_argument + 1 + invalid = nil_local + 1 def is_one_of(bytecode, candidates): @@ -150,12 +157,14 @@ def is_one_of(bytecode, candidates): Bytecodes.jump_on_true_pop, Bytecodes.jump_on_false_pop, Bytecodes.jump_on_false_top_nil, + Bytecodes.jump_if_greater, Bytecodes.jump_backward, Bytecodes.jump2, Bytecodes.jump2_on_true_top_nil, Bytecodes.jump2_on_true_pop, Bytecodes.jump2_on_false_pop, Bytecodes.jump2_on_false_top_nil, + Bytecodes.jump2_if_greater, Bytecodes.jump2_backward, ] @@ -177,6 +186,8 @@ def is_one_of(bytecode, candidates): Bytecodes.pop_inner_0, Bytecodes.pop_inner_1, Bytecodes.pop_inner_2, + Bytecodes.nil_frame, + Bytecodes.nil_inner, Bytecodes.q_super_send_1, Bytecodes.q_super_send_2, Bytecodes.q_super_send_3, @@ -198,6 +209,7 @@ def is_one_of(bytecode, candidates): _BYTECODE_LENGTH = [ 1, # halt 1, # dup + 1, # dup_second 3, # push_frame 3, # push_frame_0 3, # push_frame_1 @@ -231,6 +243,8 @@ def is_one_of(bytecode, candidates): 3, # pop_field 1, # pop_field_0 1, # pop_field_1 + 2, # nil_frame + 2, # nil_inner 2, # send_1 2, # send_2 2, # send_3 @@ -251,12 +265,14 @@ def is_one_of(bytecode, candidates): 3, # jump_on_false_top_nil 3, # jump_on_true_pop 3, # jump_on_false_pop + 3, # jump_if_greater 3, # jump_backward 3, # jump2 3, # jump2_on_true_top_nil 3, # jump2_on_false_top_nil 3, # jump2_on_true_pop 3, # jump2_on_false_pop + 3, # jump2_if_greater 3, # jump2_backward 2, # q_super_send_1 2, # q_super_send_2 @@ -267,6 +283,7 @@ def is_one_of(bytecode, candidates): 3, # push_argument 3, # pop_local 3, # pop_argument + 2, # nil_local ] diff --git a/src/som/interpreter/bc/interpreter.py b/src/som/interpreter/bc/interpreter.py index 4d983519..bc86b45f 100644 --- a/src/som/interpreter/bc/interpreter.py +++ b/src/som/interpreter/bc/interpreter.py @@ -16,7 +16,9 @@ from som.vm.globals import nilObject, trueObject, falseObject from som.vmobjects.array import Array from som.vmobjects.block_bc import BcBlock -from som.vmobjects.integer import int_0, int_1 +from som.vmobjects.integer import int_0, int_1, Integer +from som.vmobjects.double import Double + from rlib import jit from rlib.jit import promote, elidable_promote, we_are_jitted @@ -135,6 +137,11 @@ def interpret(method, frame, max_stack_size): stack_ptr += 1 stack[stack_ptr] = val + elif bytecode == Bytecodes.dup_second: + val = stack[stack_ptr - 1] + stack_ptr += 1 + stack[stack_ptr] = val + elif bytecode == Bytecodes.push_frame: stack_ptr += 1 stack[stack_ptr] = read_frame( @@ -314,6 +321,16 @@ def interpret(method, frame, max_stack_size): write_inner(frame, FRAME_AND_INNER_RCVR_IDX + 2, value) + elif bytecode == Bytecodes.nil_frame: + if we_are_jitted(): + idx = method.get_bytecode(current_bc_idx + 1) + write_frame(frame, idx, nilObject) + + elif bytecode == Bytecodes.nil_inner: + if we_are_jitted(): + idx = method.get_bytecode(current_bc_idx + 1) + write_inner(frame, idx, nilObject) + elif bytecode == Bytecodes.pop_field: field_idx = method.get_bytecode(current_bc_idx + 1) ctx_level = method.get_bytecode(current_bc_idx + 2) @@ -461,8 +478,6 @@ def interpret(method, frame, max_stack_size): elif bytecode == Bytecodes.inc: val = stack[stack_ptr] - from som.vmobjects.integer import Integer - from som.vmobjects.double import Double from som.vmobjects.biginteger import BigInteger if isinstance(val, Integer): @@ -477,8 +492,6 @@ def interpret(method, frame, max_stack_size): elif bytecode == Bytecodes.dec: val = stack[stack_ptr] - from som.vmobjects.integer import Integer - from som.vmobjects.double import Double from som.vmobjects.biginteger import BigInteger if isinstance(val, Integer): @@ -545,6 +558,23 @@ def interpret(method, frame, max_stack_size): stack[stack_ptr] = None stack_ptr -= 1 + elif bytecode == Bytecodes.jump_if_greater: + top = stack[stack_ptr] + top_2 = stack[stack_ptr - 1] + if ( + isinstance(top, Integer) + and isinstance(top_2, Integer) + and top.get_embedded_integer() > top_2.get_embedded_integer() + ) or ( + isinstance(top, Double) + and isinstance(top_2, Double) + and top.get_embedded_double() > top_2.get_embedded_double() + ): + stack[stack_ptr] = None + stack[stack_ptr - 1] = None + stack_ptr -= 2 + next_bc_idx = current_bc_idx + method.get_bytecode(current_bc_idx + 1) + elif bytecode == Bytecodes.jump_backward: next_bc_idx = current_bc_idx - method.get_bytecode(current_bc_idx + 1) jitdriver.can_enter_jit( @@ -614,6 +644,19 @@ def interpret(method, frame, max_stack_size): stack[stack_ptr] = None stack_ptr -= 1 + elif bytecode == Bytecodes.jump2_if_greater: + top = stack[stack_ptr] + top_2 = stack[stack_ptr - 1] + if top.get_embedded_integer() > top_2.get_embedded_integer(): + stack[stack_ptr] = None + stack[stack_ptr - 1] = None + stack_ptr -= 2 + next_bc_idx = ( + current_bc_idx + + method.get_bytecode(current_bc_idx + 1) + + (method.get_bytecode(current_bc_idx + 2) << 8) + ) + elif bytecode == Bytecodes.jump2_backward: next_bc_idx = current_bc_idx - ( method.get_bytecode(current_bc_idx + 1) @@ -669,6 +712,10 @@ def interpret(method, frame, max_stack_size): method.patch_variable_access(current_bc_idx) # retry bytecode after patching next_bc_idx = current_bc_idx + elif bytecode == Bytecodes.nil_local: + method.patch_variable_access(current_bc_idx) + # retry bytecode after patching + next_bc_idx = current_bc_idx else: _unknown_bytecode(bytecode, current_bc_idx, method) diff --git a/src/som/vmobjects/method.py b/src/som/vmobjects/method.py index 1d39c237..f32dafe4 100644 --- a/src/som/vmobjects/method.py +++ b/src/som/vmobjects/method.py @@ -6,12 +6,14 @@ class AbstractMethod(AbstractObject): _immutable_fields_ = [ "_signature", "_holder", + "_lexical_scope", ] - def __init__(self, signature): + def __init__(self, signature, lexical_scope): AbstractObject.__init__(self) self._signature = signature self._holder = None + self._lexical_scope = lexical_scope @staticmethod def is_primitive(): @@ -22,6 +24,9 @@ def is_invokable(): """We use this method to identify methods and primitives""" return True + def get_argument(self, idx, ctx_level): + return self._lexical_scope.get_argument(idx, ctx_level) + def get_holder(self): return self._holder diff --git a/src/som/vmobjects/method_ast.py b/src/som/vmobjects/method_ast.py index 5f8493bb..f4e94358 100644 --- a/src/som/vmobjects/method_ast.py +++ b/src/som/vmobjects/method_ast.py @@ -80,7 +80,6 @@ class AstMethod(AbstractMethod): "_size_inner", "_embedded_block_methods", "source_section", - "_lexical_scope", ] def __init__( @@ -94,7 +93,7 @@ def __init__( source_section, lexical_scope, ): - AbstractMethod.__init__(self, signature) + AbstractMethod.__init__(self, signature, lexical_scope) assert isinstance(arg_inner_access, list) make_sure_not_resized(arg_inner_access) @@ -108,8 +107,6 @@ def __init__( self.invokable = _Invokable(expr_or_sequence) - self._lexical_scope = lexical_scope - def set_holder(self, value): self._holder = value for method in self._embedded_block_methods: @@ -165,7 +162,7 @@ def invoke_args(node, rcvr, args): # pylint: disable=no-self-argument ) return node.invokable.expr_or_sequence.execute(frame) - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument mgenc.merge_into_scope(self._lexical_scope) self.invokable.expr_or_sequence.adapt_after_inlining(mgenc) return self.invokable.expr_or_sequence diff --git a/src/som/vmobjects/method_bc.py b/src/som/vmobjects/method_bc.py index 6af8c1f3..20f7e54a 100644 --- a/src/som/vmobjects/method_bc.py +++ b/src/som/vmobjects/method_bc.py @@ -15,6 +15,9 @@ emit_push_field_with_index, emit_pop_field_with_index, emit3_with_dummy, + emit_push_local, + emit_pop_local, + emit_nil_local, compute_offset, ) from som.interpreter.ast.frame import ( @@ -55,7 +58,6 @@ class BcAbstractMethod(AbstractMethod): "_arg_inner_access[*]", "_size_frame", "_size_inner", - "_lexical_scope", "_inlined_loops[*]", ] @@ -72,7 +74,7 @@ def __init__( lexical_scope, inlined_loops, ): - AbstractMethod.__init__(self, signature) + AbstractMethod.__init__(self, signature, lexical_scope) # Set the number of bytecodes in this method self._bytecodes = ["\x00"] * num_bytecodes @@ -89,8 +91,6 @@ def __init__( self._size_frame = size_frame self._size_inner = size_inner - self._lexical_scope = lexical_scope - self._inlined_loops = inlined_loops def get_number_of_locals(self): @@ -178,6 +178,13 @@ def patch_variable_access(self, bytecode_index): elif bc == Bytecodes.pop_local: var = self._lexical_scope.get_local(idx, ctx_level) self.set_bytecode(bytecode_index, var.get_pop_bytecode(ctx_level)) + elif bc == Bytecodes.nil_local: + var = self._lexical_scope.get_local(idx, 0) + if var.is_accessed_out_of_context(): + bytecode = Bytecodes.nil_inner + else: + bytecode = Bytecodes.nil_frame + self.set_bytecode(bytecode_index, bytecode) else: raise Exception("Unsupported bytecode?") assert ( @@ -243,8 +250,12 @@ def invoke_n(self, stack, stack_ptr): stack, stack_ptr, self._number_of_arguments, result ) - def inline(self, mgenc): + def merge_scope_into(self, mgenc): mgenc.merge_into_scope(self._lexical_scope) + + def inline(self, mgenc, merge_scope=True): + if merge_scope: + mgenc.merge_into_scope(self._lexical_scope) self._inline_into(mgenc) def _create_back_jump_heap(self): @@ -297,7 +308,7 @@ def _inline_into(self, mgenc): if bytecode == Bytecodes.halt: emit1(mgenc, bytecode, 0) - elif bytecode == Bytecodes.dup: + elif bytecode == Bytecodes.dup or bytecode == Bytecodes.dup_second: emit1(mgenc, bytecode, 1) elif ( @@ -308,8 +319,23 @@ def _inline_into(self, mgenc): ): idx = self.get_bytecode(i + 1) ctx_level = self.get_bytecode(i + 2) - assert ctx_level > 0 - if bytecode == Bytecodes.push_field: + + if ctx_level == 0: + assert ( + bytecode == Bytecodes.push_argument + or bytecode == Bytecodes.pop_argument + ), ( + "This should really be push or pop argument." + + " everything else should have a ctx_level > 0" + ) + + arg = self._lexical_scope.get_argument(idx, 0) + idx = mgenc.get_inlined_local_idx(arg, 0) + if bytecode == Bytecodes.push_argument: + emit_push_local(mgenc, idx, 0) + else: + emit_pop_local(mgenc, idx, 0) + elif bytecode == Bytecodes.push_field: emit_push_field_with_index(mgenc, idx, ctx_level - 1) elif bytecode == Bytecodes.pop_field: emit_pop_field_with_index(mgenc, idx, ctx_level - 1) @@ -344,6 +370,12 @@ def _inline_into(self, mgenc): else: emit3(mgenc, bytecode, idx, ctx_level, -1) + elif bytecode == Bytecodes.nil_local: + idx = self.get_bytecode(i + 1) + var = self._lexical_scope.get_local(idx, 0) + idx = mgenc.get_inlined_local_idx(var, 0) + emit_nil_local(mgenc, idx) + elif bytecode == Bytecodes.push_block: literal_idx = self.get_bytecode(i + 1) block_method = self._literals[literal_idx] @@ -425,9 +457,11 @@ def _inline_into(self, mgenc): bytecode == Bytecodes.jump or bytecode == Bytecodes.jump_on_true_top_nil or bytecode == Bytecodes.jump_on_false_top_nil + or bytecode == Bytecodes.jump_if_greater or bytecode == Bytecodes.jump2 or bytecode == Bytecodes.jump2_on_true_top_nil or bytecode == Bytecodes.jump2_on_false_top_nil + or bytecode == Bytecodes.jump2_if_greater ): # emit the jump, but instead of the offset, emit a dummy idx = emit3_with_dummy(mgenc, bytecode, 0) @@ -494,6 +528,7 @@ def adapt_after_outer_inlined(self, removed_ctx_level, mgenc_with_inlined): if ( bytecode == Bytecodes.halt or bytecode == Bytecodes.dup + or bytecode == Bytecodes.dup_second or bytecode == Bytecodes.push_block_no_ctx or bytecode == Bytecodes.push_constant or bytecode == Bytecodes.push_constant_0 @@ -520,12 +555,14 @@ def adapt_after_outer_inlined(self, removed_ctx_level, mgenc_with_inlined): or bytecode == Bytecodes.jump_on_true_pop or bytecode == Bytecodes.jump_on_false_top_nil or bytecode == Bytecodes.jump_on_false_pop + or bytecode == Bytecodes.jump_if_greater or bytecode == Bytecodes.jump_backward or bytecode == Bytecodes.jump2 or bytecode == Bytecodes.jump2_on_true_top_nil or bytecode == Bytecodes.jump2_on_true_pop or bytecode == Bytecodes.jump2_on_false_top_nil or bytecode == Bytecodes.jump2_on_false_pop + or bytecode == Bytecodes.jump2_if_greater or bytecode == Bytecodes.jump2_backward ): # don't use context @@ -542,6 +579,20 @@ def adapt_after_outer_inlined(self, removed_ctx_level, mgenc_with_inlined): ctx_level = self.get_bytecode(i + 2) if ctx_level > removed_ctx_level: self.set_bytecode(i + 2, ctx_level - 1) + elif ctx_level == removed_ctx_level and ( + bytecode == Bytecodes.push_argument + or bytecode == Bytecodes.pop_argument + ): + idx = self.get_bytecode(i + 1) + arg = self._lexical_scope.get_argument(idx, removed_ctx_level) + new_idx = mgenc_with_inlined.get_inlined_local_idx( + arg, removed_ctx_level + ) + if bytecode == Bytecodes.push_argument: + self.set_bytecode(i, Bytecodes.push_local) + else: + self.set_bytecode(i, Bytecodes.pop_local) + self.set_bytecode(i + 1, new_idx) elif bytecode == Bytecodes.push_block: literal_idx = self.get_bytecode(i + 1) @@ -566,6 +617,12 @@ def adapt_after_outer_inlined(self, removed_ctx_level, mgenc_with_inlined): elif ctx_level > removed_ctx_level: self.set_bytecode(i + 2, ctx_level - 1) + elif bytecode == Bytecodes.nil_local: + assert removed_ctx_level > 0, ( + "Don't need to adjust this bytecode, " + + "because it only operates on ctx_level==0" + ) + elif bytecode == Bytecodes.return_non_local: ctx_level = self.get_bytecode(i + 1) self.set_bytecode(i + 1, ctx_level - 1) @@ -668,7 +725,7 @@ def invoke_n(self, stack, stack_ptr): ) raise e - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): raise Exception( "Blocks should never handle non-local returns. " "So, this should not happen." diff --git a/src/som/vmobjects/method_trivial.py b/src/som/vmobjects/method_trivial.py index 510b92a2..50d9d0ec 100644 --- a/src/som/vmobjects/method_trivial.py +++ b/src/som/vmobjects/method_trivial.py @@ -4,6 +4,7 @@ emit_push_global, emit_push_field_with_index, ) +from som.compiler.ast.variable import Argument from som.interp_type import is_ast_interpreter from som.interpreter.ast.frame import FRAME_AND_INNER_RCVR_IDX from som.interpreter.bc.frame import stack_pop_old_arguments_and_push_result @@ -40,11 +41,12 @@ def get_number_of_signature_arguments(self): class LiteralReturn(AbstractTrivialMethod): - _immutable_fields_ = ["_value"] + _immutable_fields_ = ["_value", "source_section"] - def __init__(self, signature, value): - AbstractTrivialMethod.__init__(self, signature) + def __init__(self, signature, value, source_section): + AbstractTrivialMethod.__init__(self, signature, None) self._value = value + self.source_section = source_section def set_holder(self, value): self._holder = value @@ -70,22 +72,32 @@ def invoke_n(self, stack, stack_ptr): if is_ast_interpreter(): - def inline(self, _mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument from som.interpreter.ast.nodes.literal_node import LiteralNode - return LiteralNode(self._value) + if merge_scope: + self.merge_scope_into(mgenc) + return LiteralNode(self._value, self.source_section) else: - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument emit_push_constant(mgenc, self._value) + def merge_scope_into(self, mgenc): + mgenc.inline_as_locals([Argument("$inlinedI", 1, self.source_section)]) + + def get_argument(self, idx, ctx_level): + if idx == 1 and ctx_level == 0: + return Argument("$inlinedI", idx, self.source_section) + raise Exception("This should not happen") + class GlobalRead(AbstractTrivialMethod): _immutable_fields_ = ["_assoc?", "_global_name", "_context_level", "universe"] def __init__(self, signature, global_name, context_level, universe, assoc=None): - AbstractTrivialMethod.__init__(self, signature) + AbstractTrivialMethod.__init__(self, signature, None) self._assoc = assoc self._global_name = global_name self._context_level = context_level @@ -125,14 +137,14 @@ def invoke_n(self, stack, stack_ptr): if is_ast_interpreter(): - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument from som.interpreter.ast.nodes.global_read_node import create_global_node return create_global_node(self._global_name, self.universe, mgenc, None) else: - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument emit_push_global(mgenc, self._global_name) @@ -140,7 +152,7 @@ class FieldRead(AbstractTrivialMethod): _immutable_fields_ = ["_field_idx", "_context_level"] def __init__(self, signature, field_idx, context_level): - AbstractTrivialMethod.__init__(self, signature) + AbstractTrivialMethod.__init__(self, signature, None) self._field_idx = field_idx self._context_level = context_level @@ -170,14 +182,14 @@ def invoke_n(self, stack, stack_ptr): if is_ast_interpreter(): - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument from som.interpreter.ast.nodes.field_node import FieldReadNode return FieldReadNode(mgenc.get_self_read(), self._field_idx, None) else: - def inline(self, mgenc): + def inline(self, mgenc, merge_scope=True): # pylint: disable=unused-argument emit_push_field_with_index(mgenc, self._field_idx, self._context_level - 1) @@ -185,7 +197,7 @@ class FieldWrite(AbstractTrivialMethod): _immutable_fields_ = ["_field_idx", "_arg_idx"] def __init__(self, signature, field_idx, arg_idx): - AbstractTrivialMethod.__init__(self, signature) + AbstractTrivialMethod.__init__(self, signature, None) self._field_idx = field_idx self._arg_idx = arg_idx diff --git a/tests/test_ast_inlining.py b/tests/test_ast_inlining.py index e9d3c9a2..786745d2 100644 --- a/tests/test_ast_inlining.py +++ b/tests/test_ast_inlining.py @@ -23,6 +23,7 @@ IfInlinedNode, IfElseInlinedNode, ) +from som.interpreter.ast.nodes.specialized.literal_to_do import ToDoInlined from som.interpreter.ast.nodes.specialized.literal_while import WhileInlinedNode from som.interpreter.ast.nodes.variable_node import ( UninitializedReadNode, @@ -493,7 +494,7 @@ def test_to_do_block_block_inlined_self(cgenc, mgenc): b ifTrue: [ l2 := l2 + 1 ] ] ] )""", ) - block_a = ast._exprs[0]._arg_exprs[1]._value.invokable.expr_or_sequence + block_a = ast._exprs[0]._do_expr block_b_if_true = block_a._arg_exprs[0]._value.invokable.expr_or_sequence read_node = block_b_if_true._condition_expr @@ -502,13 +503,13 @@ def test_to_do_block_block_inlined_self(cgenc, mgenc): assert read_node.var.idx == 1 write_node = block_b_if_true._body_expr - assert write_node._context_level == 2 + assert write_node._context_level == 1 assert write_node._var._name == "l2" assert write_node._var.idx == 1 assert isinstance(write_node._value_expr, IntIncrementNode) read_l2_node = write_node._value_expr._rcvr_expr - assert read_l2_node._context_level == 2 + assert read_l2_node._context_level == 1 assert read_l2_node.var._name == "l2" assert read_l2_node.var.idx == 1 @@ -538,3 +539,11 @@ def test_inlining_of_or(mgenc, or_sel): ) assert isinstance(ast._exprs[0], OrInlinedNode) + + +def test_inlining_of_to_do(mgenc): + ast = parse_method(mgenc, "test = ( 1 to: 2 do: [:i | i ] )") + + assert isinstance(ast._exprs[0], ToDoInlined) + to_do = ast._exprs[0] + assert to_do._idx_var._name == "i" diff --git a/tests/test_bytecode_generation.py b/tests/test_bytecode_generation.py index 6290b020..7127f86d 100644 --- a/tests/test_bytecode_generation.py +++ b/tests/test_bytecode_generation.py @@ -1377,3 +1377,76 @@ def test_field_read_inlining(cgenc, mgenc): Bytecodes.return_self, ], ) + + +def test_inlining_of_to_do(mgenc): + bytecodes = method_to_bytecodes(mgenc, "test = ( 1 to: 2 do: [:i | i ] )") + + assert len(bytecodes) == 23 + check( + bytecodes, + [ + Bytecodes.push_1, + Bytecodes.push_constant_0, + Bytecodes.dup_second, # stack: Top[1, 2, 1] + Bytecodes.nil_local, + BC(Bytecodes.jump_if_greater, 17), # consume only on jump + Bytecodes.dup, + BC( + Bytecodes.pop_local, 0, 0 + ), # store the i into the local (arg becomes local after inlining) + BC( + Bytecodes.push_local, 0, 0 + ), # push the local on the stack as part of the block's code + Bytecodes.pop, # cleanup after block. + Bytecodes.inc, # increment top, the iteration counter + Bytecodes.nil_local, + BC( + Bytecodes.jump_backward, 14 + ), # jump back to the jump_if_greater bytecode + # jump_if_greater target + Bytecodes.return_self, + ], + ) + + +def test_to_do_block_block_inlined_self(cgenc, mgenc): + add_field(cgenc, "field") + bytecodes = method_to_bytecodes( + mgenc, + """ + test = ( + | l1 l2 | + 1 to: 2 do: [:a | + l1 do: [:b | + b ifTrue: [ + a. + l2 := l2 + 1 ] ] ] + )""", + ) + + assert len(bytecodes) == 27 + check( + bytecodes, + [ + Bytecodes.push_1, + Bytecodes.push_constant_0, + Bytecodes.dup_second, + Bytecodes.nil_local, + BC(Bytecodes.jump_if_greater, 21), + Bytecodes.dup, + BC(Bytecodes.pop_local, 2, 0), + BC(Bytecodes.push_local, 0, 0), + BC(Bytecodes.push_block, 2), + Bytecodes.send_2, + Bytecodes.pop, + Bytecodes.inc, + Bytecodes.nil_local, + BC(Bytecodes.jump_backward, 18), + Bytecodes.return_self, + ], + ) + + block_method = mgenc._literals[2] # pylint: disable=protected-access + block_bcs = block_method.get_bytecodes() + check(block_bcs, [(6, BC(Bytecodes.push_local, 2, 1)), Bytecodes.pop]) diff --git a/tests/test_printable_locations.py b/tests/test_printable_locations.py index da0caeb9..c0fed1b1 100644 --- a/tests/test_printable_locations.py +++ b/tests/test_printable_locations.py @@ -2,9 +2,14 @@ import pytest from rtruffle.source_section import SourceCoordinate, SourceSection +from som.compiler.ast.variable import Local from som.interpreter.ast.nodes.specialized.down_to_do_node import ( get_printable_location as pl_dtd, ) +from som.interpreter.ast.nodes.specialized.literal_to_do import ( + get_printable_location as pl_ltd, + ToDoInlined, +) from som.interpreter.ast.nodes.specialized.literal_while import ( get_printable_location_while as pl_while, WhileInlinedNode, @@ -41,6 +46,12 @@ def test_pl_dtd(method): assert pl_dtd(method) == "#to:do: Test>>test" +def test_pl_ltd(source_section): + node = ToDoInlined(None, None, None, Local("", 1, source_section), source_section) + + assert pl_ltd(node) == "#to:do: test.som:1:1" + + def test_while(source_section): node = WhileInlinedNode(None, None, None, source_section) assert pl_while(node) == "while test.som:1:1"