From cf4426fdb3d8800f058ae50f770db3479f9bf4a4 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 14 Dec 2023 12:00:24 +1100 Subject: [PATCH] Multinode computation. --- BMR/common.h | 1 + CHANGELOG.md | 18 + CONFIG | 1 + Compiler/GC/instructions.py | 27 +- Compiler/GC/types.py | 53 ++- Compiler/allocator.py | 70 +++- Compiler/comparison.py | 10 +- Compiler/compilerLib.py | 55 ++- Compiler/floatingpoint.py | 14 +- Compiler/instructions.py | 86 ++++- Compiler/instructions_base.py | 177 +++++++-- Compiler/library.py | 143 ++++--- Compiler/ml.py | 31 +- Compiler/oram.py | 51 ++- Compiler/program.py | 166 ++++++-- Compiler/types.py | 402 +++++++++++++++----- Dockerfile | 6 +- ECDSA/CurveElement.h | 1 - ECDSA/P256Element.cpp | 10 - ECDSA/P256Element.h | 6 - ExternalIO/bankers-bonus-client.cpp | 6 + ExternalIO/bankers-bonus-client.py | 21 +- ExternalIO/client.py | 19 +- FHE/Ciphertext.cpp | 13 +- FHE/Ciphertext.h | 1 - FHE/FHE_Keys.cpp | 23 +- FHE/FHE_Keys.h | 7 +- FHE/Ring_Element.h | 8 +- FHE/Rq_Element.cpp | 11 +- FHE/Rq_Element.h | 10 +- FHEOffline/DistKeyGen.cpp | 9 +- GC/Processor.h | 1 + GC/Processor.hpp | 7 + GC/TinyPrep.hpp | 2 +- GC/instructions.h | 4 +- Machines/ShamirMachine.hpp | 6 + Machines/emulate.cpp | 18 +- Machines/no-party.cpp | 3 +- Machines/spdz2k-party.cpp | 7 +- Makefile | 11 +- Math/Bit.h | 5 - Math/BitVec.h | 7 - Math/FixedVec.h | 6 - Math/ValueInterface.cpp | 2 + Math/Z2k.h | 2 - Math/Zp_Data.h | 2 +- Math/bigint.cpp | 6 - Math/bigint.h | 2 - Math/gf2n.h | 4 - Math/gfp.h | 4 - Math/gfpvar.cpp | 6 - Math/gfpvar.h | 2 - Networking/Player.cpp | 6 +- Networking/ServerSocket.cpp | 21 +- Networking/sockets.cpp | 6 +- Networking/sockets.h | 11 +- OT/BitMatrix.h | 2 + Processor/BaseMachine.cpp | 50 ++- Processor/BaseMachine.h | 6 + Processor/ExternalClients.cpp | 40 +- Processor/ExternalClients.h | 2 + Processor/FieldMachine.hpp | 2 +- Processor/Instruction.cpp | 8 +- Processor/Instruction.h | 7 +- Processor/Instruction.hpp | 70 +++- 
Processor/Machine.h | 4 + Processor/Machine.hpp | 32 +- Processor/Memory.h | 85 ++++- Processor/Memory.hpp | 63 +++ Processor/OnlineMachine.hpp | 12 - Processor/OnlineOptions.cpp | 72 +++- Processor/OnlineOptions.h | 6 +- Processor/OnlineOptions.hpp | 32 +- Processor/Processor.h | 28 +- Processor/Processor.hpp | 212 +++++++---- Processor/ProcessorBase.h | 13 +- Processor/RingMachine.hpp | 3 +- Processor/instructions.h | 32 +- Programs/Source/mnist_full_B.mpc | 2 +- Programs/Source/multinode_example_main.py | 34 ++ Programs/Source/multinode_example_worker.py | 21 + Programs/Source/tutorial.mpc | 3 + Protocols/AtlasShare.h | 5 + Protocols/FakeProtocol.h | 24 +- Protocols/FakeShare.h | 1 + Protocols/Hemi.h | 2 +- Protocols/Hemi.hpp | 10 +- Protocols/HemiShare.h | 5 + Protocols/MAC_Check.h | 11 +- Protocols/MAC_Check.hpp | 8 +- Protocols/MalRepRingPrep.hpp | 6 + Protocols/NoShare.h | 11 + Protocols/Rep3Shuffler.h | 15 +- Protocols/Rep3Shuffler.hpp | 26 +- Protocols/Rep4Share.h | 1 + Protocols/Replicated.h | 5 +- Protocols/Replicated.hpp | 8 - Protocols/SecureShuffle.h | 34 +- Protocols/SecureShuffle.hpp | 55 ++- Protocols/ShamirShare.h | 5 + Protocols/Share.h | 6 +- Protocols/ShareInterface.h | 2 + Protocols/ShuffleSacrifice.h | 1 + Protocols/ShuffleSacrifice.hpp | 2 +- Protocols/TemiShare.h | 5 + README.md | 13 +- Scripts/compile-emulate.py | 2 +- Scripts/compile-run.py | 16 +- Scripts/memory-usage.py | 14 +- Scripts/run-common.sh | 10 + Scripts/test_ecdsa.sh | 1 + Tools/DiskVector.cpp | 44 +++ Tools/DiskVector.h | 84 ++++ Tools/Exceptions.cpp | 7 + Tools/Exceptions.h | 6 + Tools/parse.h | 8 + Yao/YaoPlayer.cpp | 10 + azure-pipelines.yml | 2 +- deps/libOTe | 2 +- doc/compilation.rst | 11 +- doc/index.rst | 5 +- doc/instructions.rst | 2 + doc/io.rst | 9 +- doc/journey.rst | 228 +++++++++++ doc/machine-learning.rst | 2 +- doc/multinode.rst | 65 ++++ doc/non-linear.rst | 5 + doc/preprocessing.rst | 17 +- doc/troubleshooting.rst | 22 +- doc/utils.rst | 23 ++ 130 files 
changed, 2505 insertions(+), 799 deletions(-) create mode 100644 Programs/Source/multinode_example_main.py create mode 100644 Programs/Source/multinode_example_worker.py create mode 100644 Tools/DiskVector.cpp create mode 100644 Tools/DiskVector.h create mode 100644 doc/journey.rst create mode 100644 doc/multinode.rst create mode 100644 doc/utils.rst diff --git a/BMR/common.h b/BMR/common.h index de9ffad4e..27bced3c9 100644 --- a/BMR/common.h +++ b/BMR/common.h @@ -10,6 +10,7 @@ #include #include #include +#include using namespace std; #include "Tools/CheckVector.h" diff --git a/CHANGELOG.md b/CHANGELOG.md index 313bd3de3..6abb6df67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.8 (December 14, 2023) + +- Functionality for multiple nodes per party +- Functionality to use disk space for high-level data structures +- True division is always fixed-point division (similar to Python 3) +- Compiler option to optimize for specific protocol +- Cleartext permutation +- Faster compilation and lower bytecode size +- Functionality to output secret shares from high-level code +- Run-time command-line arguments accessible from high-level code +- Client connection setup specifies cleartext domain +- Compile-time parameter for connection timeout +- Prevent connections from timing out (@ParallelogramPal) +- More ECDSA examples +- More flexible multiplication instruction +- Dot product instruction supports several operations at once +- Example-based virtual machine explanation + ## 0.3.7 (August 14, 2023) - Path Oblivious Heap (@tskovlund) diff --git a/CONFIG b/CONFIG index f6f436294..ba108166c 100644 --- a/CONFIG +++ b/CONFIG @@ -87,6 +87,7 @@ LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS) LDLIBS += $(BREW_LDLIBS) LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS 
+= -lboost_system -lssl -lcrypto +LDLIBS += -lboost_filesystem -lboost_iostreams CFLAGS += -I./local/include diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 4fc2fe7c4..1b53f9300 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -80,17 +80,17 @@ class ClearBitsAF(base.RegisterArgFormat): CONVCBITVEC = 0x231, ) -class BinaryVectorInstruction(base.Instruction): - is_vec = lambda self: True +class BinaryCiscable(base.Ciscable): + pass - def copy(self, size, subs): - return type(self)(*self.get_new_args(size, subs)) +class BinaryVectorInstruction(BinaryCiscable): + is_vec = lambda self: True class NonVectorInstruction(base.Instruction): is_vec = lambda self: False def __init__(self, *args, **kwargs): - assert(args[0].n <= args[0].unit) + assert(args[0].n is None or args[0].n <= args[0].unit) super(NonVectorInstruction, self).__init__(*args, **kwargs) class NonVectorInstruction1(base.Instruction): @@ -163,7 +163,7 @@ def add_usage(self, req_node): sum(int(math.ceil(x / 64)) for x in self.args[::4])) class andrsvec(base.VarArgsInstruction, base.Mergeable, - base.DynFormatInstruction): + base.DynFormatInstruction, BinaryCiscable): """ Constant-vector AND of secret bit registers (vectorized version). :param: total number of arguments to follow (int) @@ -206,6 +206,9 @@ def add_usage(self, req_node): req_node.increment(('bit', 'triple'), size * (n - 3) // 2) req_node.increment(('bit', 'mixed'), size) + def copy(self, size, subs): + return type(self)(*self.get_new_args(size, subs)) + class ands(BinaryVectorInstruction): """ Bitwise AND of secret bit register vector. @@ -306,7 +309,7 @@ class bitcoms(NonVectorInstruction, base.VarArgsInstruction): arg_format = tools.chain(['sbw'], itertools.repeat('sb')) class bitdecc(NonVectorInstruction, base.VarArgsInstruction): - """ Secret bit register decomposition. + """ Clear bit register decomposition. 
:param: number of arguments to follow / number of bits plus one (int) :param: source (sbit) @@ -513,8 +516,8 @@ class convcbitvec(BinaryVectorInstruction): """ code = opcodes['CONVCBITVEC'] arg_format = ['int','ciw','cb'] - def __init__(self, *args): - super(convcbitvec, self).__init__(*args) + def __init__(self, *args, **kwargs): + super(convcbitvec, self).__init__(*args, **kwargs) assert(args[2].n == args[0]) args[1].set_size(args[0]) @@ -546,14 +549,14 @@ def __init__(self, *args, **kwargs): super(split_class, self).__init__(*args, **kwargs) assert (len(args) - 2) % args[0] == 0 -class movsb(NonVectorInstruction): +class movsb(BinaryVectorInstruction): """ Copy secret bit register. :param: destination (sbit) :param: source (sbit) """ code = opcodes['MOVSB'] - arg_format = ['sbw','sb'] + arg_format = ['int', 'sbw','sb'] class trans(base.VarArgsInstruction, base.DynFormatInstruction): """ Secret bit register vector transpose. The first destination vector @@ -568,8 +571,6 @@ class trans(base.VarArgsInstruction, base.DynFormatInstruction): """ code = opcodes['TRANS'] is_vec = lambda self: True - def __init__(self, *args): - super(trans, self).__init__(*args) @classmethod def dynamic_arg_format(cls, args): diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 0c79091f1..acbd354ed 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -97,8 +97,9 @@ def bit_decompose_clear(a, n_bits): cbits.conv_cint_vec(a, *res) return res @classmethod - def malloc(cls, size, creator_tape=None): - return Program.prog.malloc(size, cls, creator_tape=creator_tape) + def malloc(cls, size, creator_tape=None, **kwargs): + return Program.prog.malloc(size, cls, creator_tape=creator_tape, + **kwargs) @staticmethod def n_elements(): return 1 @@ -254,6 +255,18 @@ def expand(self, length): return self.get_type(length).bit_compose([self] * length) else: raise CompilerError('cannot expand from %s to %s' % (self.n, length)) + @classmethod + def new_vector(cls, size): + return 
cls.get_type(size)() + @classmethod + def concat(cls, parts): + return cls.bit_compose( + sum([part.bit_decompose() for part in parts], [])) + def copy_from_part(self, source, base, size): + self.mov(self, + self.bit_compose(source.bit_decompose()[base:base + size])) + def vector_size(self): + return self.n class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -425,7 +438,7 @@ def conv_regint_by_bit(cls, n, res, other): tmp = cbits.get_type(n)() tmp.conv_regint_by_bit(n, tmp, other) res.load_other(tmp) - mov = inst.movsb + mov = staticmethod(lambda x, y: inst.movsb(x.n, x, y)) types = {} def __init__(self, *args, **kwargs): bits.__init__(self, *args, **kwargs) @@ -1048,6 +1061,9 @@ def f(res): class sbit(bit, sbits): """ Single secret bit. """ + @classmethod + def get_type(cls, length): + return sbits.get_type(length) def if_else(self, x, y): """ Non-vectorized oblivious selection:: @@ -1301,6 +1317,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ bit_extend = staticmethod(_complement_two_extend) + mul_functions = {} @classmethod def popcnt_bits(cls, bits): return sbitvec.from_vec(bits).popcnt() @@ -1326,20 +1343,42 @@ def __mul__(self, other): elif isinstance(other, sbitfixvec): return NotImplemented my_bits, other_bits = self.expand(other, False) - matrix = [] m = float('inf') + uniform = True for x in itertools.chain(my_bits, other_bits): try: + uniform &= type(x) == type(my_bits[0]) and x.n == my_bits[0].n m = min(m, x.n) except: pass + if uniform and Program.prog.options.cisc: + bl = len(my_bits) + key = bl, len(other_bits) + if key not in self.mul_functions: + def instruction(*args): + res = self.binary_mul(args[bl:2 * bl], args[2 * bl:], + args[0].n) + for x, y in zip(res, args): + x.mov(y, x) + instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits)) + self.mul_functions[key] = instructions_base.cisc(instruction, + bl) + res = [sbits.get_type(m)() for i in range(bl)] + 
self.mul_functions[key](*(res + my_bits + other_bits)) + return self.from_vec(res) + else: + return self.binary_mul(my_bits, other_bits, m) + @classmethod + def binary_mul(cls, my_bits, other_bits, m): + matrix = [] for i, b in enumerate(other_bits): if m == 1: - matrix.append([x * b for x in my_bits[:len(self.v)-i]]) + matrix.append([x * b for x in my_bits[:len(my_bits)-i]]) else: - matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v) + matrix.append(( + sbitvec.from_vec(my_bits[:len(my_bits)-i]) * b).v) v = sbitint.wallace_tree_from_matrix(matrix) - return self.from_vec(v[:len(self.v)]) + return cls.from_vec(v[:len(my_bits)]) __rmul__ = __mul__ reduce_after_mul = lambda x: x def TruncMul(self, other, k, m, kappa=None, nearest=False): diff --git a/Compiler/allocator.py b/Compiler/allocator.py index fe1848035..f2154cabe 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -104,9 +104,10 @@ def consolidate(self): break class AllocPool: - def __init__(self): + def __init__(self, parent=None): self.ranges = defaultdict(lambda: [AllocRange()]) self.by_base = {} + self.parent = parent def alloc(self, reg_type, size): for r in self.ranges[reg_type]: @@ -116,8 +117,17 @@ def alloc(self, reg_type, size): return res def free(self, reg): - r = self.by_base.pop((reg.reg_type, reg.i)) - r.free(reg.i, reg.size) + try: + r = self.by_base.pop((reg.reg_type, reg.i)) + r.free(reg.i, reg.size) + except KeyError: + try: + self.parent.free(reg) + except: + if program.Program.prog.options.debug: + print('Error with freeing register with trace:') + print(util.format_trace(reg.caller)) + print() def new_ranges(self, min_usage): for t, n in min_usage.items(): @@ -133,7 +143,10 @@ def consolidate(self): rr.consolidate() def n_fragments(self): - return max(len(r) for r in self.ranges) + if self.ranges: + return max(len(r) for r in self.ranges) + else: + return 0 class StraightlineAllocator: """Allocate variables in a straightline program using n registers. 
@@ -146,6 +159,7 @@ def __init__(self, n, program): assert(n == REG_MAX) self.program = program self.old_pool = None + self.unused = defaultdict(lambda: 0) def alloc_reg(self, reg, free): base = reg.vectorbase @@ -195,7 +209,8 @@ def dealloc_reg(self, reg, inst, free): for x in itertools.chain(dup.duplicates, base.duplicates): to_check.add(x) - free.free(base) + if reg not in self.program.base_addresses: + free.free(base) if inst.is_vec() and base.vector: self.defined[base] = inst for i in base.vector: @@ -220,8 +235,11 @@ def process(self, program, alloc_pool): if unused_regs and len(unused_regs) == len(list(i.get_def())) and \ self.program.verbose: # only report if all assigned registers are unused - print("Register(s) %s never used, assigned by '%s' in %s" % \ - (unused_regs,i,format_trace(i.caller))) + self.unused[type(i).__name__] += 1 + if self.program.verbose > 1: + print( + "Register(s) %s never used, assigned by '%s' in %s" % \ + (unused_regs,i,format_trace(i.caller))) for j in i.get_used(): self.alloc_reg(j, alloc_pool) @@ -277,6 +295,7 @@ def p(sizes): x = reg.reg_type, reg.size print('Used registers: ', end='') p(sizes) + print('Unused instructions:', dict(self.unused)) def determine_scope(block, options): last_def = defaultdict_by_id(lambda: -1) @@ -421,6 +440,7 @@ def dependency_graph(self, merge_classes): last = defaultdict(lambda: defaultdict(lambda: None)) last_open = deque() last_input = defaultdict(lambda: [None, None]) + mem_scopes = defaultdict_by_id(lambda: MemScope()) depths = [0] * len(block.instructions) self.depths = depths @@ -429,6 +449,12 @@ def dependency_graph(self, merge_classes): self.sources = [] self.real_depths = [0] * len(block.instructions) round_type = {} + shuffles = defaultdict_by_id(set) + + class MemScope: + def __init__(self): + self.read = [] + self.write = [] def add_edge(i, j): if i in (-1, j): @@ -581,14 +607,20 @@ def keep_text_order(inst, n): depths[n] = depth if isinstance(instr, ReadMemoryInstruction): - if 
options.preserve_mem_order or instr._protect: + if options.preserve_mem_order: strict_mem_access(n, last_mem_read, last_mem_write) - elif not options.preserve_mem_order: + elif instr._protect: + scope = mem_scopes[instr._protect] + strict_mem_access(n, scope.read, scope.write) + if not options.preserve_mem_order: mem_access(n, instr, last_mem_read_of, last_mem_write_of) elif isinstance(instr, WriteMemoryInstruction): - if options.preserve_mem_order or instr._protect: + if options.preserve_mem_order: strict_mem_access(n, last_mem_write, last_mem_read) - elif not options.preserve_mem_order: + elif instr._protect: + scope = mem_scopes[instr._protect] + strict_mem_access(n, scope.write, scope.read) + if not options.preserve_mem_order: mem_access(n, instr, last_mem_write_of, last_mem_read_of) elif isinstance(instr, matmulsm): if options.preserve_mem_order: @@ -608,6 +640,11 @@ def keep_text_order(inst, n): keep_order(instr, n, instr.args[0]) elif isinstance(instr, StackInstruction): keep_order(instr, n, StackInstruction) + elif isinstance(instr, applyshuffle): + shuffles[instr.args[3]].add(n) + elif isinstance(instr, delshuffle): + for i_inst in shuffles[instr.args[0]]: + add_edge(i_inst, n) if not G.pred[n]: self.sources.append(n) @@ -683,6 +720,7 @@ def __init__(self): self.cache = util.dict_by_id() self.offset_cache = util.dict_by_id() self.rev_offset_cache = {} + self.range_cache = util.dict_by_id() def add_offset(self, res, new_base, new_offset): self.offset_cache[res] = new_base, new_offset @@ -693,6 +731,12 @@ def run(self, instructions, program): for i, inst in enumerate(instructions): if isinstance(inst, ldint_class): self.cache[inst.args[0]] = inst.args[1] + elif isinstance(inst, incint): + if inst.args[2] == 1 and inst.args[3] == 1 and \ + inst.args[4] == len(inst.args[0]) and \ + inst.args[1] in self.cache: + self.range_cache[inst.args[0]] = \ + len(inst.args[0]), self.cache[inst.args[1]] elif isinstance(inst, IntegerInstruction): if inst.args[1] in 
self.cache and inst.args[2] in self.cache: res = inst.op(self.cache[inst.args[1]], @@ -731,6 +775,10 @@ def f(base, delta_reg): base, offset = self.offset_cache[inst.args[1]] addr = self.rev_offset_cache[base.i, offset] inst.args[1] = addr + elif inst.args[1] in self.range_cache: + size, base = self.range_cache[inst.args[1]] + if size == len(inst.args[0]): + instructions[i] = inst.get_direct(base) elif type(inst) == convint_class: if inst.args[1] in self.cache: res = self.cache[inst.args[1]] diff --git a/Compiler/comparison.py b/Compiler/comparison.py index cf818570a..6a1d76023 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -65,6 +65,8 @@ def ld2i(c, n): movc(c, t1) def require_ring_size(k, op): + if not program.options.ring: + return if int(program.options.ring) < k: msg = 'ring size too small for %s, compile ' \ 'with \'-R %d\' or more' % (op, k) @@ -140,7 +142,7 @@ def TruncRing(d, a, k, m, signed): high = sint.conv(carries[length]) else: if m == 1: - low = x[1][1] + low = x[0][1] high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \ sint.conv(x[0][-1]) else: @@ -181,7 +183,7 @@ def TruncLeakyInRing(a, k, m, signed): if k == m: return 0 assert k > m - assert int(program.options.ring) >= k + require_ring_size(k, 'leaky truncation') from .types import sint, intbitint, cint, cgf2n n_bits = k - m n_shift = int(program.options.ring) - n_bits @@ -228,7 +230,7 @@ def Mod2m(a_prime, a, k, m, kappa, signed): movs(a_prime, program.non_linear.mod2m(a, k, m, signed)) def Mod2mRing(a_prime, a, k, m, signed): - assert(int(program.options.ring) >= k) + require_ring_size(k, 'modulo power of two') from Compiler.types import sint, intbitint, cint shift = int(program.options.ring) - m r_prime, r_bin = MaskingBitsInRing(m, True) @@ -404,7 +406,7 @@ def carry(b, a, compute_p=True): return b if b is None: return a - t = [program.curr_block.new_reg('s') for i in range(3)] + t = [None] * 3 if compute_p: t[0] = a[0].bit_and(b[0]) t[2] = a[0].bit_and(b[1]) + a[1] 
diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index f34674a7b..b2e88b5d8 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -13,12 +13,28 @@ class Compiler: - def __init__(self, custom_args=None, usage=None, execute=False): + def __init__(self, custom_args=None, usage=None, execute=False, + split_args=False): if usage: self.usage = usage else: self.usage = "usage: %prog [options] filename [args]" self.execute = execute + self.runtime_args = [] + + if split_args: + if custom_args is None: + args = sys.argv + else: + args = custom_args + try: + split = args.index('--') + except ValueError: + split = len(args) + + custom_args = args[1:split] + self.runtime_args = args[split + 1:] + self.custom_args = custom_args self.build_option_parser() self.VARS = {} @@ -148,7 +164,8 @@ def build_option_parser(self): "--prime", dest="prime", default=defaults.prime, - help="prime modulus (default: not specified)", + help="use bit decomposition with a specifed prime modulus " + "for non-linear computation (default: use the masking approach)", ) parser.add_option( "-I", @@ -235,13 +252,31 @@ def build_option_parser(self): dest="hostfile", help="hosts to execute with", ) + else: + parser.add_option( + "-E", + "--execute", + dest="execute", + help="protocol to optimize for", + ) self.parser = parser def parse_args(self): self.options, self.args = self.parser.parse_args(self.custom_args) if self.execute: if not self.options.execute: - raise CompilerError("must give name of protocol with '-E'") + if len(self.args) > 1: + self.options.execute = self.args.pop(0) + else: + self.parser.error("missing protocol name") + if self.options.hostfile: + try: + open(self.options.hostfile) + except: + print('hostfile %s not found' % self.options.hostfile, + file=sys.stderr) + exit(1) + if self.options.execute: protocol = self.options.execute if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \ protocol.find("brain") >= 0 or protocol == "emulate": @@ -268,14 
+303,14 @@ def parse_args(self): def build_program(self, name=None): self.prog = Program(self.args, self.options, name=name) - if self.execute: + if self.options.execute: if self.options.execute in \ ("emulate", "ring", "rep-field", "rep4-ring"): self.prog.use_trunc_pr = True if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"): self.prog.use_split(3) if self.options.execute in ("semi2k",): - self.prog.use_split(2) + self.prog.use_split(int(os.getenv("PLAYERS", 2))) if self.options.execute in ("rep4-ring",): self.prog.use_split(4) @@ -476,7 +511,9 @@ def executable_from_protocol(protocol): else: return protocol + "-party.x" - def local_execution(self, args=[]): + def local_execution(self, args=None): + if args is None: + args = self.runtime_args executable = self.executable_from_protocol(self.options.execute) if not os.path.exists("%s/%s" % (self.root, executable)): print("Creating binary for virtual machine...") @@ -488,9 +525,13 @@ def local_execution(self, args=[]): "Note that compilation requires a few GB of RAM.") vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute) sys.stdout.flush() + print("Compilation finished, running program...", file=sys.stderr) + sys.stderr.flush() os.execl(vm, vm, self.prog.name, *args) - def remote_execution(self, args=[]): + def remote_execution(self, args=None): + if args is None: + args = self.runtime_args vm = self.executable_from_protocol(self.options.execute) hosts = list(x.strip() for x in filter(None, open(self.options.hostfile))) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index bd5c13844..0dcf4f818 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -286,6 +286,7 @@ def BitDecRingRaw(a, k, m): bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return bits +@instructions_base.bit_cisc def BitDecRing(a, k, m): bits = BitDecRingRaw(a, k, m) # reversing to reduce number of rounds @@ -304,6 +305,7 @@ def BitDecFieldRaw(a, k, m, kappa, 
bits_to_compute=None): instructions_base.reset_global_vector_size() return res +@instructions_base.bit_cisc def BitDecField(a, k, m, kappa, bits_to_compute=None): res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute) return [types.sintbit.conv(bit) for bit in res] @@ -358,7 +360,6 @@ def B2U_from_Pow2(pow2a, l, kappa): def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): """ Oblivious truncation by secret m """ prog = program.Program.prog - kappa = kappa or prog.security if util.is_constant(m) and not compute_modulo: # cheaper res = type(a)(size=a.size) @@ -371,6 +372,8 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): return a * (1 - m) if program.Program.prog.options.ring and not compute_modulo: return TruncInRing(a, l, Pow2(m, l, kappa)) + else: + kappa = kappa or program.Program.prog.security r = [types.sint() for i in range(l)] r_dprime = types.sint(0) r_prime = types.sint(0) @@ -409,6 +412,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): b = shifted - d return b +@instructions_base.ret_cisc def TruncInRing(to_shift, l, pow2m): n_shift = int(program.Program.prog.options.ring) - l bits = BitDecRing(to_shift, l, l) @@ -433,11 +437,7 @@ def SplitInRing(a, l, m): def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa): t = comparison.TruncRoundNearest(a, length, length - target_length, kappa) overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa) - if program.Program.prog.options.ring: - s = (1 - overflow) * t + \ - comparison.TruncLeakyInRing(overflow * t, length, 1, False) - else: - s = (1 - overflow) * t + overflow * t / 2 + s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False) return s, overflow def Int2FL(a, gamma, l, kappa=None): @@ -555,7 +555,7 @@ def TruncPrField(a, k, m, kappa=None): c = (b + r).reveal(False) c_prime = c % two_to_m a_prime = c_prime - r_prime - d = (a - a_prime) / two_to_m + d = (a - a_prime).field_div(two_to_m) return d 
@instructions_base.ret_cisc diff --git a/Compiler/instructions.py b/Compiler/instructions.py index f3cb6ea66..f97d84121 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -345,6 +345,17 @@ class starg(base.Instruction): code = base.opcodes['STARG'] arg_format = ['ci'] +@base.vectorize +class cmdlinearg(base.Instruction): + """ Load command-line argument. + + :param: dest (regint) + :param: index (regint) + + """ + code = base.opcodes['CMDLINEARG'] + arg_format = ['ciw','ci'] + @base.gf2n class reqbl(base.Instruction): """ Requirement on computation modulus. Minimal bit length of prime if @@ -654,7 +665,7 @@ class picks(base.VectorInstruction): def __init__(self, *args): super(picks, self).__init__(*args) assert 0 <= args[2] < len(args[1]) - assert 0 <= args[2] + args[3] * len(args[0]) <= len(args[1]) + assert 0 <= args[2] + args[3] * (len(args[0]) - 1) < len(args[1]) class concats(base.VectorInstruction): """ Concatenate vectors. @@ -1630,6 +1641,16 @@ class print_reg_plain(base.IOInstruction): code = base.opcodes['PRINTREGPLAIN'] arg_format = ['c'] +class print_reg_plains(base.IOInstruction): + """ Output secret register. + + :param: source (sint) + + """ + __slots__ = [] + code = base.opcodes['PRINTREGPLAINS'] + arg_format = ['s'] + class cond_print_plain(base.IOInstruction): """ Conditionally output clear register (with precision). Outputs :math:`x \cdot 2^p` where :math:`p` is the precision. @@ -1860,6 +1881,19 @@ class acceptclientconnection(base.IOInstruction): code = base.opcodes['ACCEPTCLIENTCONNECTION'] arg_format = ['ciw', 'ci'] +class initclientconnection(base.IOInstruction): + """ Initialize connection. 
+ + :param: client id destination (regint) + :param: port number (regint) + :param: my client id (regint) + :param: hostname (variable string) + + """ + __slots__ = [] + code = base.opcodes['INITCLIENTCONNECTION'] + arg_format = ['ciw', 'ci', 'ci', 'varstr'] + class closeclientconnection(base.IOInstruction): """ Close connection to client. @@ -1941,7 +1975,7 @@ class fixinput(base.PublicFileIOInstruction): :param: player (int) :param: destination (cint) - :param: exponent (int) + :param: exponent (int, for float/double) / byte length (1/8, for integer) :param: input type (0: 64-bit integer, 1: float, 2: double) """ @@ -2284,30 +2318,30 @@ def merge(self, other): self.args += other.args[1:] @base.gf2n -@base.vectorize -class muls(base.VarArgsInstruction, base.DataInstruction): +class muls(base.VarArgsInstruction, base.DataInstruction, base.Ciscable): """ (Element-wise) multiplication of secret registers (vectors). - :param: number of arguments to follow (multiple of three) + :param: number of arguments to follow (multiple of four) + :param: vector size (int) :param: result (sint) :param: factor (sint) :param: factor (sint) - :param: (repeat the last three)... + :param: (repeat the last four)... 
""" __slots__ = [] code = base.opcodes['MULS'] - arg_format = tools.cycle(['sw','s','s']) + arg_format = tools.cycle(['int','sw','s','s']) data_type = 'triple' + is_vec = lambda self: True - def get_repeat(self): - return len(self.args) // 3 + def __init__(self, *args, **kwargs): + super(muls_class, self).__init__(*args, **kwargs) + for i in range(0, len(args), 4): + for j in range(3): + assert args[i + j + 1].size == args[i] - def merge_id(self): - # can merge different sizes - # but not if large - if self.get_size() is None or self.get_size() > 100: - return type(self), self.get_size() - return type(self) + def get_repeat(self): + return sum(self.args[::4]) # def expand(self): # s = [program.curr_block.new_reg('s') for i in range(9)] @@ -2324,6 +2358,16 @@ def merge_id(self): # adds(s[8], s[7], s[6]) # addm(self.args[0], s[8], c[2]) +# compatibility +try: + vmuls = muls_class + muls_bak = muls + muls = lambda *args: muls_bak(args[0].size, *args) + vgmuls = gmuls_class = gmuls + gmuls = lambda *args: gmuls_class(args[0].size, *args) +except NameError: + pass + @base.gf2n class mulrs(base.VarArgsInstruction, base.DataInstruction): """ Constant-vector multiplication of secret registers. @@ -2403,8 +2447,8 @@ def gf2n_arg_format(self): return self.arg_format() def get_repeat(self): - return sum(self.args[i] // 2 - for i, n in self.bases(iter(self.args))) * self.get_size() + return sum(self.args[i] // 2 - 1 + for i, n in self.bases(iter(self.args))) def get_def(self): return [self.args[i + 1] for i, n in self.bases(iter(self.args))] @@ -2421,7 +2465,7 @@ class matmul_base(base.DataInstruction): def get_repeat(self): return reduce(operator.mul, self.args[3:6]) -class matmuls(matmul_base): +class matmuls(matmul_base, base.Mergeable): """ Secret matrix multiplication from registers. All matrices are represented as vectors in row-first order. 
@@ -2433,7 +2477,11 @@ class matmuls(matmul_base): :param: number of columns in second factor and result (int) """ code = base.opcodes['MATMULS'] - arg_format = ['sw','s','s','int','int','int'] + arg_format = itertools.cycle(['sw','s','s','int','int','int']) + + def get_repeat(self): + return sum(reduce(operator.mul, self.args[i + 3:i + 6]) + for i in range(0, len(self.args), 6)) class matmulsm(matmul_base): """ Secret matrix multiplication reading directly from memory. diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 57ff46197..9e88ec58b 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -67,6 +67,7 @@ USE_EDABIT = 0xE5, USE_MATMUL = 0x1F, ACTIVE = 0xE9, + CMDLINEARG = 0xEB, # Addition ADDC = 0x20, ADDS = 0x21, @@ -152,7 +153,7 @@ LISTEN = 0x6c, ACCEPTCLIENTCONNECTION = 0x6d, CLOSECLIENTCONNECTION = 0x6e, - READCLIENTPUBLICKEY = 0x6f, + INITCLIENTCONNECTION = 0x6f, # Bitwise logic ANDC = 0x70, XORC = 0x71, @@ -196,6 +197,7 @@ PRINTREG = 0XB1, RAND = 0xB2, PRINTREGPLAIN = 0xB3, + PRINTREGPLAINS = 0xEA, PRINTCHR = 0xB4, PRINTSTR = 0xB5, PUBINPUT = 0xB6, @@ -422,28 +424,31 @@ def maybe_gf2n_instruction(*args, **kwargs): class Mergeable: pass -def cisc(function): +def cisc(function, n_outputs=1): class MergeCISC(Mergeable): instructions = {} + functions = {} def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs - self.security = program.security + self.security = program._security self.calls = [(args, kwargs)] self.params = [] self.used = [] - for arg in self.args[1:]: + for arg in self.args[n_outputs:]: if isinstance(arg, program.curr_tape.Register): self.used.append(arg) self.params.append(type(arg)) else: self.params.append(arg) self.function = function + self.caller = None program.curr_block.instructions.append(self) def get_def(self): - return [call[0][0] for call in self.calls] + return sum(([call[0][i] for call in self.calls] + for i in range(n_outputs)), []) def get_used(self): 
return self.used @@ -460,7 +465,7 @@ def merge(self, other): self.used += other.used def get_size(self): - return self.args[0].size + return self.args[0].vector_size() def new_instructions(self, size, regs): if self.merge_id() not in self.instructions: @@ -474,11 +479,11 @@ def new_instructions(self, size, regs): args = [] for arg in self.args: try: - args.append(type(arg)(size=None)) + args.append(arg.new_vector(size=None)) except: args.append(arg) program.options.cisc = False - old_security = program.security + old_security = program._security program.security = self.security self.function(*args, **self.kwargs) program.security = old_security @@ -490,7 +495,8 @@ def new_instructions(self, size, regs): from Compiler.allocator import Merger merger = Merger(block, program.options, tuple(program.to_merge)) - args[0].can_eliminate = False + for i in range(n_outputs): + args[i].can_eliminate = False merger.eliminate_dead_code() assert int(program.options.max_parallel_open) == 0, \ 'merging restriction not compatible with ' \ @@ -501,52 +507,105 @@ def new_instructions(self, size, regs): n_rounds template, args, self.n_rounds = self.instructions[self.merge_id()] subs = util.dict_by_id() + from Compiler import types for arg, reg in zip(args, regs): - subs[arg] = reg + if isinstance(arg, program.curr_tape.Register): + subs[arg] = reg set_global_vector_size(size) for inst in template: inst.copy(size, subs) reset_global_vector_size() + def expand_to_function(self, size, new_regs): + key = size, program.curr_tape, \ + tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \ + tuple(type(reg) for reg in new_regs) + if key not in self.functions: + from Compiler import library, types + from Compiler.GC.types import bits + class Arg: + def __init__(self, reg): + self.type = type(reg) + self.binary = isinstance(reg, bits) + self.reg = reg + # if reg is not None: + # program.base_addresses[reg] = None + def new(self): + if self.binary: + return self.type() + else: + 
return self.type(size=size) + def load(self): + return self.reg + def store(self, reg): + if self.type != type(None): + self.reg.update(reg) + args = [Arg(x) for x in new_regs] + @library.function_block + def f(): + res = [arg.new() for arg in args[:n_outputs]] + self.new_instructions(size, + res + [arg.load() for arg in args[n_outputs:]]) + for reg, arg in zip(res, args): + arg.store(reg) + f.name = '_'.join(['%s(%d)' % (function.__name__, size)] + + [str(x) for x in key[2]]) + self.functions[key] = f, args + f, args = self.functions[key] + for i in range(len(new_regs) - n_outputs): + args[n_outputs + i].store(new_regs[n_outputs + i]) + f() + for i in range(n_outputs): + new_regs[i].link(args[i].load()) + def expand_merged(self, skip): if function.__name__ in skip: good = True for call in self.calls: if not good: break - for arg in call[0]: - if isinstance(arg, program.curr_tape.Register) and \ - not issubclass(type(self.calls[0][0][0]), type(arg)): - good = False + for i in range(n_outputs): + for arg in call[0]: + if isinstance(arg, program.curr_tape.Register) and \ + not issubclass(type(self.calls[0][0][0]), + type(arg)): + good = False if good: - return [self], 0 + return program.curr_block.instructions.append(self) + if program.verbose: + print('expanding', self.function.__name__) tape = program.curr_tape - block = tape.BasicBlock(tape, None, None) - tape.active_basicblock = block - size = sum(call[0][0].size for call in self.calls) + tape.start_new_basicblock() + size = sum(call[0][0].vector_size() for call in self.calls) new_regs = [] for i, arg in enumerate(self.args): try: - if i == 0: - new_regs.append(type(arg)(size=size)) + if i < n_outputs: + new_regs.append(arg.new_vector(size=size)) else: new_regs.append(type(arg).concat( call[0][i] for call in self.calls)) - assert len(new_regs[-1]) == size + assert new_regs[-1].vector_size() == size except (TypeError, AttributeError): - if not isinstance(arg, int): + if not isinstance(arg, (int, type(None))): 
raise - break + new_regs.append(None) except: - print([call[0][0].size for call in self.calls]) + print([call[0][0].vector_size() for call in self.calls]) raise - self.new_instructions(size, new_regs) + if program.cisc_to_function and \ + (program.curr_tape.singular or program.n_running_threads): + self.expand_to_function(size, new_regs) + else: + self.new_instructions(size, new_regs) + program.curr_block.n_rounds += self.n_rounds - 1 base = 0 for call in self.calls: - reg = call[0][0] - reg.copy_from_part(new_regs[0], base, reg.size) - base += reg.size - return block.instructions, self.n_rounds - 1 + for i in range(n_outputs): + reg = call[0][i] + reg.copy_from_part(new_regs[i], base, reg.vector_size()) + base += reg.vector_size() + tape.start_new_basicblock() def add_usage(self, *args): pass @@ -605,10 +664,10 @@ def wrapper(*args, **kwargs): from Compiler import types if not (program.options.cisc and isinstance(args[0], types._register)): return function(*args, **kwargs) - if isinstance(args[0], types._clear): - res_type = type(args[1]) - else: - res_type = type(args[0]) + for arg in args: + if isinstance(arg, types._secret): + res_type = type(arg) + break res = res_type(size=args[0].size) instruction(res, *args, **kwargs) return res @@ -642,6 +701,24 @@ def wrapper(*args, **kwargs): copy_doc(wrapper, function) return wrapper +bit_instructions = {} + +def bit_cisc(function): + def wrapper(a, k, m, *args, **kwargs): + key = function, m + if key not in bit_instructions: + def instruction(*args, **kwargs): + res = function(*args[m:], **kwargs) + for x, y in zip(res, args): + x.mov(y, x) + instruction.__name__ = '%s(%d)' % (function.__name__, m) + bit_instructions[key] = cisc(instruction, m) + from Compiler.types import sintbit + res = [sintbit() for i in range(m)] + bit_instructions[function, m](*res, a, k, m, *args, **kwargs) + return res + return wrapper + class RegType(object): """ enum-like static class for Register types """ ClearModp = 'c' @@ -793,6 +870,23 
@@ def __init__(self, f): def __str__(self): return self.str +class VarString(ArgFormat): + @classmethod + def check(cls, arg): + if not isinstance(arg, str): + raise ArgumentError(arg, 'Argument is not string') + + @classmethod + def encode(cls, arg): + return int_to_bytes(len(arg)) + list(bytearray(arg, 'ascii')) + + def __init__(self, f): + length = IntArgFormat(f).i + self.str = str(f.read(length), 'ascii') + + def __str__(self): + return self.str + ArgFormats = { 'c': ClearModpAF, 's': SecretModpAF, @@ -810,6 +904,7 @@ def __str__(self): 'long': LongArgFormat, 'p': PlayerNoAF, 'str': String, + 'varstr': VarString, } def format_str_is_reg(format_str): @@ -930,7 +1025,7 @@ def merge(self, other): self.args += other.args def expand_vector_args(self): - if self.is_vec(): + if self.is_vec() and self.get_size() != 1: for arg in self.args: arg.create_vector_elements() res = sum(list(zip(*self.args)), ()) @@ -939,7 +1034,7 @@ def expand_vector_args(self): return self.args def expand_merged(self, skip): - return [self], 0 + program.curr_block.instructions.append(self) def get_new_args(self, size, subs): new_args = [] @@ -956,6 +1051,10 @@ def get_new_args(self, size, subs): new_args.append(arg) return new_args + def copy(self, *args, **kwargs): + raise CompilerError("%s instruction not compatible with CISC-style " + "merging. Compile with '-O'." 
% type(self)) + @staticmethod def get_usage(args): return {} @@ -990,9 +1089,9 @@ def __init__(self, f): pass read = lambda: struct.unpack('>I', f.read(4))[0] full_code = struct.unpack('>Q', f.read(8))[0] - code = full_code % (1 << Instruction.code_length) + self.code = full_code % (1 << Instruction.code_length) self.size = full_code >> Instruction.code_length - self.type = cls.reverse_opcodes[code] + self.type = cls.reverse_opcodes[self.code] t = self.type name = t.__name__ try: @@ -1044,6 +1143,10 @@ class VectorInstruction(Instruction): def get_code(self): return super(VectorInstruction, self).get_code(len(self.args[0])) +class Ciscable(Instruction): + def copy(self, size, subs): + return type(self)(*self.get_new_args(size, subs), copying=True) + class DynFormatInstruction(Instruction): __slots__ = [] diff --git a/Compiler/library.py b/Compiler/library.py index 0f3303e6b..7d8c9d278 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -7,7 +7,8 @@ from Compiler.instructions import * from Compiler.util import tuplify,untuplify,is_zero from Compiler.allocator import RegintOptimizer, AllocPool -from Compiler import instructions,instructions_base,comparison,program,util +from Compiler.program import Tape +from Compiler import instructions,instructions_base,comparison,util,types import inspect,math import random import collections @@ -42,7 +43,7 @@ def vectorized_function(*args, **kwargs): def set_instruction_type(function): def instruction_typed_function(*args, **kwargs): - if len(args) > 0 and isinstance(args[0], program.Tape.Register): + if len(args) > 0 and isinstance(args[0], Tape.Register): if args[0].is_gf2n: instructions_base.set_global_instruction_type('gf2n') else: @@ -59,9 +60,15 @@ def instruction_typed_function(*args, **kwargs): def _expand_to_print(val): return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val) -def print_str(s, *args): +def print_str(s, *args, print_secrets=False): """ Print a string, with optional args for 
adding - variables/registers with ``%s``. """ + variables/registers with ``%s``. + + :param s: format string + :param args: arguments (any type) + :param print_secrets: whether to output secret shares + + """ def print_plain_str(ss): """ Print a plain string (no custom formatting options) """ ss = bytearray(ss, 'utf8') @@ -84,11 +91,15 @@ def print_plain_str(ss): val = args[i].read() else: val = args[i] - if isinstance(val, program.Tape.Register): + if isinstance(val, Tape.Register): if val.is_clear: val.print_reg_plain() + elif print_secrets and isinstance(val, sint): + val.output() else: - raise CompilerError('Cannot print secret value:', args[i]) + raise CompilerError( + 'Cannot print secret value %s, activate printing of shares with ' + "'print_secrets=True'" % args[i]) elif isinstance(val, cfix): val.print_plain() elif isinstance(val, sfix) or isinstance(val, sfloat): @@ -100,16 +111,17 @@ def print_plain_str(ss): else: try: val.output() - except AttributeError: + except (AttributeError, TypeError): print_plain_str(str(val)) -def print_ln(s='', *args): +def print_ln(s='', *args, **kwargs): """ Print line, with optional args for adding variables/registers with ``%s``. By default only player 0 outputs, but the ``-I`` command-line option changes that. :param s: Python string with same number of ``%s`` as length of :py:obj:`args` :param args: list of public values (regint/cint/int/cfix/cfloat/localint) + :param print_secrets: whether to output secret shares Example: @@ -117,7 +129,7 @@ def print_ln(s='', *args): print_ln('a is %s.', a.reveal()) """ - print_str(str(s) + '\n', *args) + print_str(str(s) + '\n', *args, **kwargs) def print_both(s, end='\n'): """ Print line during compilation and execution. """ @@ -169,7 +181,7 @@ def print_str_if(cond, ss, *args): def print_ln_to(player, ss, *args): """ Print line at :py:obj:`player` only. Note that printing is disabled by default except at player 0. Activate interactive mode - with `-I` to enable it for all players. 
+ with `-I` or use `-OF .` to enable it for all players. :param player: int :param ss: Python string @@ -295,8 +307,14 @@ def get_arg(): ldarg(res) return res +def get_cmdline_arg(idx): + """ Return run-time command-line argument. """ + res = regint() + cmdlinearg(res, regint.conv(idx)) + return localint(res) + def make_array(l, t=None): - if isinstance(l, program.Tape.Register): + if isinstance(l, Tape.Register): res = Array(len(l), t or type(l)) res[:] = l else: @@ -337,14 +355,15 @@ def __call__(self, *args): # first call type_args = collections.defaultdict(list) for i,arg in enumerate(args): - type_args[get_reg_type(arg)].append(i) + if not isinstance(arg, types._vectorizable): + type_args[get_reg_type(arg)].append(i) def wrapped_function(*compile_args): base = get_arg() bases = dict((t, regint.load_mem(base + i)) \ for i,t in enumerate(sorted(type_args, key=lambda x: x.reg_type))) - runtime_args = [None] * len(args) + runtime_args = list(args) for t in sorted(type_args, key=lambda x: x.reg_type): i = 0 for i_arg in type_args[t]: @@ -407,13 +426,14 @@ def unmemorize(x): class FunctionBlock(Function): def on_first_call(self, wrapped_function): + p_return_address = get_tape().program.malloc(1, 'ci') old_block = get_tape().active_basicblock - parent_node = get_tape().req_node + parent_node = old_block.req_node get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) block = get_tape().active_basicblock - block.alloc_pool = AllocPool() + block.alloc_pool = AllocPool(parent=block.alloc_pool) del parent_node.children[-1] - self.node = get_tape().req_node + self.node = block.req_node if get_program().verbose: print('Compiling function', self.name) result = wrapped_function(*self.compile_args) @@ -423,7 +443,6 @@ def on_first_call(self, wrapped_function): self.result = None if get_program().verbose: print('Done compiling function', self.name) - p_return_address = get_tape().program.malloc(1, 'ci') get_tape().function_basicblocks[block] = p_return_address 
return_address = regint.load_mem(p_return_address) get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False)) @@ -446,7 +465,7 @@ def on_call(self, base, bases): return_address.store_in_mem(p_return_address) get_tape().start_new_basicblock(name='call-' + self.name) get_tape().active_basicblock.set_return(old_block, self.last_sub_block) - get_tape().req_node.children.append(self.node) + get_block().req_node.children.append(self.node) if self.result is not None: return unmemorize(self.result) @@ -654,7 +673,7 @@ def loop_fn(i): and isinstance(step, int): # known loop count if condition(start): - get_tape().req_node.children[-1].aggregator = \ + get_block().req_node.children[-1].aggregator = \ lambda x: int(ceil(((stop - start) / step))) * x[0] def for_range(start, stop=None, step=None): @@ -680,9 +699,6 @@ def _(i): x.update(x + 1) print_ln('%s', x.reveal()) - Note that you cannot overwrite data structures such as - :py:class:`~Compiler.types.Array` in a loop. Use - :py:func:`~Compiler.types.Array.assign` instead. 
""" def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -791,7 +807,7 @@ def decorator(loop_body): loop_rounds = n_loops // n_parallel \ if n_parallel < n_loops else 0 else: - loop_rounds = n_loops / n_parallel + loop_rounds = n_loops // n_parallel def write_state_to_memory(r): if use_array: mem_state.assign(r) @@ -821,6 +837,8 @@ def f(i): n_opt_loops_reg = regint(0) n_opt_loops_inst = get_block().instructions[-1] parent_block = get_block() + prevent_breaks = get_program().prevent_breaks + get_program().prevent_breaks = False @while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0)) def _(i): state = tuplify(initializer()) @@ -846,6 +864,7 @@ def _(i): loop_rounds = n_loops // my_n_parallel blocks = get_tape().basicblocks n_to_merge = 5 + get_program().prevent_breaks = prevent_breaks if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]: # merge blocks started by if and do_while def exit_elimination(block): @@ -857,19 +876,22 @@ def exit_elimination(block): merged.exit_condition = blocks[-1].exit_condition merged.exit_block = blocks[-1].exit_block assert parent_block is blocks[-n_to_merge] - assert blocks[-n_to_merge + 1] is \ - get_tape().req_node.children[-1].nodes[0].blocks[0] + assert blocks[-n_to_merge + 1].req_node is \ + get_block().req_node.children[-1].nodes[0] for block in blocks[-n_to_merge + 1:]: merged.instructions += block.instructions exit_elimination(block) block.purge(retain_usage=False) del blocks[-n_to_merge + 1:] - del get_tape().req_node.children[-1] + del get_block().req_node.children[-1] merged.children = [] RegintOptimizer().run(merged.instructions, get_program()) get_tape().active_basicblock = merged else: - req_node = get_tape().req_node.children[-1].nodes[0] + if get_program().verbose: + print(n_opt_loops, 'repetitions') + assert not get_program().prevent_breaks + req_node = get_block().req_node.children[-1].nodes[0] if util.is_constant(loop_rounds): req_node.children[0].aggregator = lambda x: 
loop_rounds * x[0] if isinstance(n_loops, int): @@ -892,7 +914,8 @@ def returner(): return returner return decorator -def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}): +def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}, + budget=None): """ Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads` threads, up to :py:obj:`n_parallel` in parallel per thread. @@ -902,9 +925,10 @@ def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}): """ return map_reduce(n_threads, n_parallel, n_loops, \ - lambda *x: [], lambda *x: [], thread_mem_req) + lambda *x: [], lambda *x: [], thread_mem_req, + budget=budget) -def for_range_opt_multithread(n_threads, n_loops): +def for_range_opt_multithread(n_threads, n_loops, budget=None): """ Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads` threads, in parallel up to an optimization budget per thread @@ -943,7 +967,7 @@ def _(i): b = a + 1 """ - return for_range_multithread(n_threads, None, n_loops) + return for_range_multithread(n_threads, None, n_loops, budget=budget) def multithread(n_threads, n_items=None, max_size=None): """ @@ -983,7 +1007,7 @@ def _(i): return wrapper def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ - thread_mem_req={}, looping=True): + thread_mem_req={}, looping=True, budget=None): assert(n_threads != 0) if isinstance(n_loops, (list, tuple)): split = n_loops @@ -1025,10 +1049,13 @@ def decorator(loop_body): state_type = type(state[0]) else: state_type = type(state) + prevent_breaks = get_program().prevent_breaks def f(inc): + get_program().prevent_breaks = prevent_breaks base = args[get_arg()][0] + get_program().base_addresses[base] = None if not util.is_constant(thread_rounds): - i = base / thread_rounds + i = base // thread_rounds overhang = n_loops % n_threads inc = i < overhang base += inc.if_else(i, overhang) @@ -1050,6 +1077,7 @@ def f(i): if prog.curr_tape == prog.tapes[0]: 
prog.n_running_threads = n_threads if not util.is_zero(thread_rounds): + prog.prevent_breaks = False tape = prog.new_tape(f, (0,), 'multithread') for i in range(n_threads - remainder): mem_state = make_array(initializer()) @@ -1058,6 +1086,7 @@ def f(i): args[remainder + i][1] = mem_state.address thread_args.append((tape, remainder + i)) if remainder: + prog.prevent_breaks = False tape1 = prog.new_tape(f, (1,), 'multithread1') for i in range(remainder): mem_state = make_array(initializer()) @@ -1066,10 +1095,12 @@ def f(i): args[i][1] = mem_state.address thread_args.append((tape1, i)) prog.n_running_threads = None + prog.prevent_breaks = False threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) prog.free_later() + prog.prevent_breaks = prevent_breaks if len(state): if thread_rounds: for i in range(n_threads - remainder): @@ -1266,7 +1297,7 @@ def _link(pre, g): if g: from .types import _single for name, var in pre.items(): - if isinstance(var, (program.Tape.Register, _single, _vec)): + if isinstance(var, (Tape.Register, _single, _vec)): new_var = g[name] if util.is_constant_float(new_var): raise CompilerError('cannot reassign constants in blocks') @@ -1285,7 +1316,7 @@ def _(): return regint(0) """ scope = instructions.program.curr_block - parent_node = get_tape().req_node + parent_node = get_block().req_node # possibly unknown loop count get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \ name='begin-loop') @@ -1334,9 +1365,10 @@ def else_then(): raise CompilerError('else block already defined') # run the else block state.if_exit_block = instructions.program.curr_block - state.req_child.add_node(get_tape(), 'else-block') + req_node = state.req_child.add_node(get_tape(), 'else-block') instructions.program.curr_tape.start_new_basicblock(state.start_block, \ - name='else-block') + name='else-block', + req_node=req_node) state.else_block = instructions.program.curr_block state.has_else = True @@ -1545,6 +1577,23 @@ def 
accept_client_connection(port): instructions.acceptclientconnection(res, regint.conv(port)) return res +def init_client_connection(host, port, my_id, relative_port=True): + """ Initiate connection to another party as client. + + :param host: hostname + :param port: port base (int/regint/cint) + :param my_id: client id to use + :param relative_port: whether to add party number to port number + :returns: connection id + + """ + if relative_port: + port = (port + get_player_id())._v + res = regint() + instructions.initclientconnection( + res, regint.conv(port), regint.conv(my_id), host) + return res + def break_point(name=''): """ Insert break point. This makes sure that all following code @@ -1643,7 +1692,9 @@ def cint_cint_division(a, b, k, f): return (sign_a * sign_b) * A from Compiler.program import Program -def sint_cint_division(a, b, k, f, kappa): + +@instructions_base.ret_cisc +def sint_cint_division(a, b, k, f, kappa, nearest=False): """ type(a) = sint, type(b) = cint """ @@ -1659,12 +1710,11 @@ def sint_cint_division(a, b, k, f, kappa): B = absolute_b W = w0 - @for_range(1, theta) - def block(i): - A.link(TruncPr(A * W, 2*k, f, kappa)) - temp = (B * W) >> f - W.link(two - temp) - B.link(temp) + for i in range(1, theta): + A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True) + temp = (B * W + 2 * (f - 1)) >> f + W = two - temp + B = temp return (sign_a * sign_b) * A def IntDiv(a, b, k, kappa=None): @@ -1691,13 +1741,16 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): assert 2 * f > k - nearest theta = int(ceil(log(k/3.5) / log(2))) + l_y = k + 3 * f - res_f + comparison.require_ring_size( + l_y, 'division (https://www.ifca.ai/pub/fc10/31_47.pdf)') + base.set_global_vector_size(b.size) alpha = b.get_type(2 * k).two_power(2*f, size=b.size) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() - l_y = k + 3 * f - res_f y = a.extend(l_y) * w y = 
y.round(l_y, f, kappa, nearest, signed=True) diff --git a/Compiler/ml.py b/Compiler/ml.py index 98677e1fc..6f34f595d 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -834,7 +834,7 @@ def compute_f_input(self, batch): prod = MultiArray([N, self.d, self.d_out], sfix) else: prod = self.f_input - max_size = program.Program.prog.budget // self.d_out + max_size = get_program().budget // self.d_out @multithread(self.n_threads, N, max_size) def _(base, size): X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) @@ -1038,16 +1038,16 @@ def __init__(self, shape, inputs=None): self.inputs = inputs def f_part(self, base, size): - return self.f(self.X.get_part_vector(base, size)) + return self.f(self.X.get_vector(base, size)) def f_prime_part(self, base, size): return self.f_prime(self.Y.get_vector(base, size)) def _forward(self, batch=[0]): n_per_item = reduce(operator.mul, self.X.sizes[1:]) - @multithread(self.n_threads, len(batch), max(1, 1000 // n_per_item)) + @multithread(self.n_threads, len(batch) * n_per_item) def _(base, size): - self.Y.assign_part_vector(self.f_part(base, size), base) + self.Y.assign_vector(self.f_part(base, size), base) if self.debug_output: name = self @@ -1095,9 +1095,9 @@ def __init__(self, shape, inputs=None): self.comparisons = MultiArray(shape, sint) def f_part(self, base, size): - x = self.X.get_part_vector(base, size) + x = self.X.get_vector(base, size) c = x > 0 - self.comparisons.assign_part_vector(c, base) + self.comparisons.assign_vector(c, base) return c.if_else(x, 0) def f_prime_part(self, base, size): @@ -1686,12 +1686,9 @@ def _forward(self, batch): padding_h, padding_w = self.padding if self.use_conv2ds: - n_parts = max(1, round((self.n_threads or 1) / n_channels_out)) - while len(batch) % n_parts != 0: - n_parts -= 1 - print('Convolution in %d parts' % n_parts) - part_size = len(batch) // n_parts - @for_range_multithread(self.n_threads, 1, [n_parts, n_channels_out]) + part_size = 1 + 
@for_range_opt_multithread(self.n_threads, + [len(batch), n_channels_out]) def _(i, j): inputs = self.X.get_slice_vector( batch.get_part(i * part_size, part_size)) @@ -2507,6 +2504,10 @@ def _(j): loss = self.layers[-1].average_loss(N) res = (loss < stop_on_loss) * (loss >= -1) self.stopped_on_loss.write(1 - res) + print_ln_if( + self.stopped_on_loss, + 'aborting epoch because loss is outside range: %s', + loss) return res if self.print_losses: print_ln() @@ -2545,7 +2546,7 @@ def reveal_correctness(self, data, truth, batch_size=128, running=False): loss = MemValue(sfix(0)) def f(start, batch_size, batch): batch.assign_vector(regint.inc(batch_size, start)) - self.forward(batch=batch) + self.forward(batch=batch, run_last=False) part_truth = truth.get_part(start, batch_size) n_correct.iadd( self.layers[-1].reveal_correctness(batch_size, part_truth)) @@ -2644,7 +2645,7 @@ def _(i): batch = Array.create_from(regint.inc(batch_size)) self.forward(batch=batch, training=True) self.backward(batch=batch) - self.update(0, batch=batch) + self.update(0, batch=batch, i_batch=0) return @for_range(n_runs) def _(i): @@ -2697,6 +2698,8 @@ def _(): if depreciation: self.gamma.imul(depreciation) print_ln('reducing learning rate to %s', self.gamma) + print_ln_if(self.stopped_on_low_loss, + 'aborting run because of low loss') return 1 - self.stopped_on_low_loss if self.missing_newline: print_ln('') diff --git a/Compiler/oram.py b/Compiler/oram.py index b370e9c3c..6f70b7b70 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -20,7 +20,7 @@ from functools import reduce from Compiler.types import * -from Compiler.types import _secret +from Compiler.types import _secret, _register from Compiler.library import * from Compiler.program import Program from Compiler import floatingpoint,comparison,permutation @@ -77,8 +77,8 @@ def __init__(self, value, start, lengths, entries_per_block): self.lower, self.shift = \ floatingpoint.Trunc(self.value, self.n_bits, self.start, \ 
Program.prog.security, True) - trunc = (self.value - self.lower) / self.shift - self.slice = trunc.mod2m(length, self.n_bits, False) + trunc = (self.value - self.lower).field_div(self.shift) + self.slice = trunc.mod2m(length, self.n_bits, signed=False) self.upper = (trunc - self.slice) * self.shift def get_slice(self): total_length = sum(self.lengths) @@ -89,13 +89,11 @@ def get_slice(self): res = [] remainder = self.slice for length,start in zip(self.lengths[:-1],series(self.lengths)): - res.append(remainder.mod2m(length, total_length - start, False)) + res.append(remainder.mod2m(length, total_length - start, + signed=False)) remainder -= res[-1] - if Program.prog.options.ring: - remainder = remainder.trunc_zeros(length, - total_length - start, False) - else: - remainder /= floatingpoint.two_power(length) + remainder = remainder.trunc_zeros(length, + total_length - start, False) res.append(remainder) return res def set_slice(self, value): @@ -208,23 +206,39 @@ def demux_list(x): return res def demux_array(x, res=None): + tmp = demux_matrix(x).array + if res: + try: + assert issubclass(x.value_type, _register) + res[:] = tmp[:] + except: + @for_range(len(res)) + def _(i): + res[i] = tmp[i] + else: + res = tmp + return res + +def demux_matrix(x, n_threads=None): n = len(x) - if res is None: - res = Array(2**n, type(x[0])) + if n == 0: + return [1] + m = len(x[0]) + t = type(x[0]) + res = Matrix(2**n, m, type(x[0])) if n == 1: res[0] = 1 - x[0] res[1] = x[0] else: - a = Array(2**(n//2), type(x[0])) + a = Matrix(2**(n//2), m, type(x[0])) a.assign(demux(x[:n//2])) - b = Array(2**(n-n//2), type(x[0])) + b = Matrix(2**(n-n//2), m, type(x[0])) b.assign(demux(x[n//2:])) - @for_range_multithread(get_n_threads(len(res)), \ - max(1, n_parallel // len(b)), len(a)) + @for_range_opt_multithread(n_threads, len(a)) def f(i): - @for_range_parallel(n_parallel, len(b)) + @for_range_opt(len(b)) def f(j): - res[j * len(a) + i] = a[i] * b[j] + res[j * len(a) + i][:] = a[i][:] * b[j][:] 
return res def get_first_one(x): @@ -1717,7 +1731,8 @@ def delete(self, *args, **kwargs): def OptimalORAM(size,*args,**kwargs): """ Create an ORAM instance suitable for the size based on - experiments. + experiments. This uses the approach by `Keller and Scholl + `_. :param size: number of elements :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / diff --git a/Compiler/program.py b/Compiler/program.py index 00503528f..119caafb6 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -11,6 +11,7 @@ import re import sys import hashlib +import random from collections import defaultdict, deque from functools import reduce @@ -74,8 +75,8 @@ class Program(object): """A program consists of a list of tapes representing the whole computation. - When compiling an :file:`.mpc` file, the single instances is - available as :py:obj:`program` in order. When compiling directly + When compiling an :file:`.mpc` file, the single instance is + available as :py:obj:`program`. When compiling directly from Python code, an instance has to be created before running any instructions. 
""" @@ -89,6 +90,7 @@ def __init__(self, args, options=defaults, name=None): self.name = name self.init_names(args) self._security = 40 + self.used_security = 0 self.prime = None self.tapes = [] if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1: @@ -113,7 +115,8 @@ def __init__(self, args, options=defaults, name=None): if not self.bit_length: self.bit_length = 64 print("Default bit length for compilation:", self.bit_length) - print("Default security parameter for compilation:", self.security) + if not (options.binary or options.garbled): + print("Default security parameter for compilation:", self._security) self.galois_length = int(options.galois) if self.verbose: print("Galois length:", self.galois_length) @@ -185,10 +188,13 @@ def __init__(self, args, options=defaults, name=None): self.relevant_opts = set() self.n_running_threads = None self.input_files = {} - self.base_addresses = {} + self.base_addresses = util.dict_by_id() self._protect_memory = False + self.mem_protect_stack = [] self._always_active = True self.active = True + self.prevent_breaks = False + self.cisc_to_function = True if not self.options.cisc: self.options.cisc = not self.options.optimize_hard @@ -249,7 +255,7 @@ def init_names(self, args): else: raise CompilerError( "found none of the potential input files: " + - ", ".join("'%s'" % x for x in [args[0]] + infiles)) + ", ".join("'%s'" % x for x in infiles)) """ self.name is input file name (minus extension) + any optional arguments. 
Used to generate output filenames @@ -352,7 +358,8 @@ def run_tapes(self, args): ) self.curr_tape.start_new_basicblock(name="post-run_tape") for arg in args: - self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree) + self.curr_block.req_node.children.append( + self.tapes[arg[0]].req_tree) return thread_numbers def join_tape(self, thread_number): @@ -400,6 +407,7 @@ def write_bytes(self): sch_file.write("lgp:%s" % req) sch_file.write("\n") sch_file.write("opts: %s\n" % " ".join(self.relevant_opts)) + sch_file.write("sec:%d\n" % self.used_security) sch_file.close() h = hashlib.sha256() for tape in self.tapes: @@ -433,7 +441,7 @@ def curr_block(self): """The basic block that is currently being created.""" return self.curr_tape.active_basicblock - def malloc(self, size, mem_type, reg_type=None, creator_tape=None): + def malloc(self, size, mem_type, reg_type=None, creator_tape=None, use_freed=True): """Allocate memory from the top""" if not isinstance(size, int): raise CompilerError("size must be known at compile time") @@ -456,7 +464,7 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): else: raise CompilerError("cannot allocate memory " "outside main thread") blocks = self.free_mem_blocks[mem_type] - addr = blocks.pop(size) + addr = blocks.pop(size) if use_freed else None if addr is not None: self.saved += size else: @@ -469,11 +477,13 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool if single_size: from .library import get_thread_number, runtime_error_if - + bak = self.curr_tape.active_basicblock + self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0] tn = get_thread_number() runtime_error_if(tn > self.n_running_threads, "malloc") res = addr + single_size * (tn - 1) - self.base_addresses[str(res)] = addr + self.curr_tape.active_basicblock = bak + self.base_addresses[res] = addr return res else: return addr @@ -482,7 +492,7 @@ 
def free(self, addr, mem_type): """Free memory""" now = True if not util.is_constant(addr): - addr = self.base_addresses[str(addr)] + addr = self.base_addresses[addr] now = self.curr_tape == self.tapes[0] size, pool = self.allocated_mem_blocks[addr, mem_type] if self.curr_block.alloc_pool is not pool: @@ -524,7 +534,8 @@ def finalize(self): self.public_input_file.close() def finalize_memory(self): - self.curr_tape.start_new_basicblock(None, "memory-usage") + self.curr_tape.start_new_basicblock(None, "memory-usage", + req_node=self.curr_tape.req_tree) # reset register counter to 0 if not self.options.noreallocate: self.curr_tape.init_registers() @@ -575,6 +586,7 @@ def set_security(self, security): def security(self): """The statistical security parameter for non-linear functions.""" + self.used_security = max(self.used_security, self._security) return self._security @security.setter @@ -701,6 +713,13 @@ def protect_memory(self, status): """ Enable or disable memory protection. """ self._protect_memory = status + def open_memory_scope(self, key=None): + self.mem_protect_stack.append(self._protect_memory) + self.protect_memory(key or object()) + + def close_memory_scope(self): + self.protect_memory(self.mem_protect_stack.pop()) + def use_cisc(self): return self.options.cisc and (not self.prime or self.rabbit_gap()) \ and not self.options.max_parallel_open @@ -725,7 +744,7 @@ def semi_honest(self): self._always_active = False @staticmethod - def read_tapes(schedule): + def read_schedule(schedule): m = re.search(r"([^/]*)\.mpc", schedule) if m: schedule = m.group(1) @@ -733,7 +752,7 @@ def read_tapes(schedule): schedule = "Programs/Schedules/%s.sch" % schedule try: - lines = open(schedule).readlines() + return open(schedule).readlines() except FileNotFoundError: print( "%s not found, have you compiled the program?" 
% schedule, @@ -741,9 +760,25 @@ def read_tapes(schedule): ) sys.exit(1) + @classmethod + def read_tapes(cls, schedule): + lines = cls.read_schedule(schedule) for tapename in lines[2].split(" "): yield tapename.strip().split(":")[0] + @classmethod + def read_n_threads(cls, schedule): + return int(cls.read_schedule(schedule)[0]) + + @classmethod + def read_domain_size(cls, schedule): + from Compiler.instructions import reqbl_class + tapename = cls.read_schedule(schedule)[2].strip().split(":")[0] + for inst in Tape.read_instructions(tapename): + if inst.code == reqbl_class.code: + bl = inst.args[0] + return (abs(bl.i) + 63) // 64 * 8 + class Tape: """A tape contains a list of basic blocks, onto which instructions are added.""" @@ -755,13 +790,12 @@ def __init__(self, name, program): self.init_names(name) self.init_registers() self.req_tree = self.ReqNode(name) - self.req_node = self.req_tree self.basicblocks = [] self.purged = False self.block_counter = 0 self.active_basicblock = None self.old_allocated_mem = program.allocated_mem.copy() - self.start_new_basicblock() + self.start_new_basicblock(req_node=self.req_tree) self._is_empty = False self.merge_opens = True self.if_states = [] @@ -774,7 +808,8 @@ def __init__(self, name, program): self.warned_about_mem = False class BasicBlock(object): - def __init__(self, parent, name, scope, exit_condition=None): + def __init__(self, parent, name, scope, exit_condition=None, + req_node=None): self.parent = parent self.instructions = [] self.name = name @@ -794,6 +829,8 @@ def __init__(self, parent, name, scope, exit_condition=None): self.n_to_merge = 0 self.rounds = Tape.ReqNum() self.warn_about_mem = parent.program.warn_about_mem[-1] + self.req_node = req_node + self.used_from_scope = set() def __len__(self): return len(self.instructions) @@ -860,17 +897,25 @@ def add_usage(self, req_node): req_node.num += self.rounds def expand_cisc(self): - new_instructions = [] if self.parent.program.options.keep_cisc is not None: - skip 
= ["LTZ", "Trunc"] + skip = ["LTZ", "Trunc", "EQZ"] skip += self.parent.program.options.keep_cisc.split(",") else: skip = [] + tape = self.parent + tape.start_new_basicblock(scope=self.scope, req_node=self.req_node, + name="cisc") + start_block = tape.basicblocks[-1] + start_block.alloc_pool = self.alloc_pool for inst in self.instructions: - new_inst, n_rounds = inst.expand_merged(skip) - new_instructions.extend(new_inst) - self.n_rounds += n_rounds - self.instructions = new_instructions + inst.expand_merged(skip) + self.instructions = tape.active_basicblock.instructions + if start_block == tape.basicblocks[-1]: + res = self + else: + res = start_block + tape.basicblocks[-1] = self + return res def __str__(self): return self.name @@ -885,7 +930,8 @@ def is_empty(self): self._is_empty = len(self.basicblocks) == 0 return self._is_empty - def start_new_basicblock(self, scope=False, name=""): + def start_new_basicblock(self, scope=False, name="", req_node=None): + assert not self.program.prevent_breaks if self.program.verbose and self.active_basicblock and \ self.program.allocated_mem != self.old_allocated_mem: print("New allocated memory in %s " % self.active_basicblock.name, @@ -900,10 +946,12 @@ def start_new_basicblock(self, scope=False, name=""): scope = self.active_basicblock suffix = "%s-%d" % (name, self.block_counter) self.block_counter += 1 - sub = self.BasicBlock(self, self.name + "-" + suffix, scope) + if req_node is None: + req_node = self.active_basicblock.req_node + sub = self.BasicBlock(self, self.name + "-" + suffix, scope, + req_node=req_node) self.basicblocks.append(sub) self.active_basicblock = sub - self.req_node.add_block(sub) # print 'Compiling basic block', sub.name def init_registers(self): @@ -1054,12 +1102,20 @@ def optimize(self, options): print("Re-allocating...") allocator = al.StraightlineAllocator(REG_MAX, self.program) + # make addresses available in functions + for addr in self.program.base_addresses: + if addr.program == self and 
self.basicblocks: + allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool) + + seen = set() + def alloc(block): allocator.update_usage(block.alloc_pool) for reg in sorted( block.used_from_scope, key=lambda x: (x.reg_type, x.i) ): allocator.alloc_reg(reg, block.alloc_pool) + seen.add(block) def alloc_loop(block): left = deque([block]) @@ -1067,7 +1123,8 @@ def alloc_loop(block): block = left.popleft() alloc(block) for child in block.children: - left.append(child) + if child not in seen: + left.append(child) allocator.old_pool = None for i, block in enumerate(reversed(self.basicblocks)): @@ -1101,6 +1158,8 @@ def alloc_loop(block): # offline data requirements if self.program.verbose: print("Compile offline data requirements...") + for block in self.basicblocks: + block.req_node.add_block(block) self.req_num = self.req_tree.aggregate() if self.program.verbose: print("Tape requires", self.req_num) @@ -1160,8 +1219,24 @@ def alloc_loop(block): @unpurged def expand_cisc(self): + mapping = {None: None} + blocks = self.basicblocks[:] + self.basicblocks = [] + for block in blocks: + expanded = block.expand_cisc() + mapping[block] = expanded + for block in self.basicblocks: + if block not in mapping: + mapping[block] = block for block in self.basicblocks: - block.expand_cisc() + block.exit_block = mapping[block.exit_block] + if block.exit_block is not None: + assert block.exit_block in self.basicblocks + if block.previous_block and mapping[block] != block: + mapping[block].previous_block = block.previous_block + mapping[block].sub_block = block.sub_block + block.previous_block = None + del block.sub_block @unpurged def _get_instructions(self): @@ -1320,27 +1395,38 @@ def __repr__(self): return repr(dict(self)) class ReqNode(object): - __slots__ = ["num", "children", "name", "blocks"] + __slots__ = ["num", "_children", "name", "blocks", "aggregated"] def __init__(self, name): - self.children = [] + self._children = [] self.name = name self.blocks = [] + self.aggregated = 
None + + @property + def children(self): + self.aggregated = None + return self._children def aggregate(self, *args): + if self.aggregated is not None: + return self.aggregated self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) res = reduce( lambda x, y: x + y.aggregate(self.name), self.children, self.num ) + self.aggregated = res return res def increment(self, data_type, num=1): self.num[data_type] += num + self.aggregated = None def add_block(self, block): self.blocks.append(block) + self.aggregated = None class ReqChild(object): __slots__ = ["aggregator", "nodes", "parent"] @@ -1369,18 +1455,18 @@ def aggregate(self, name): def add_node(self, tape, name): new_node = Tape.ReqNode(name) self.nodes.append(new_node) - tape.req_node = new_node + return new_node def open_scope(self, aggregator, scope=False, name=""): - child = self.ReqChild(aggregator, self.req_node) - self.req_node.children.append(child) - child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) - self.start_new_basicblock(name=name) + req_node = self.active_basicblock.req_node + child = self.ReqChild(aggregator, req_node) + req_node.children.append(child) + node = child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) + self.start_new_basicblock(name=name, req_node=node) return child def close_scope(self, outer_scope, parent_req_node, name): - self.req_node = parent_req_node - self.start_new_basicblock(outer_scope, name) + self.start_new_basicblock(outer_scope, name, req_node=parent_req_node) def require_bit_length(self, bit_length, t="p"): if t == "p": @@ -1553,7 +1639,7 @@ def update(self, other): diff_block = isinstance(other, Tape.Register) and self.block != other.block other = type(self)(other) if not diff_block: - self.program.start_new_basicblock() + self.program.start_new_basicblock(name="update") if self.program != other.program: raise CompilerError( 'cannot update register with one from another thread') @@ -1575,6 +1661,8 @@ def is_clear(self): ) def 
__str__(self): - return self.reg_type + str(self.i) + return self.reg_type + str(self.i) + \ + ("(%d)" % self.size if self.size is not None and self.size > 1 + else "") __repr__ = __str__ diff --git a/Compiler/types.py b/Compiler/types.py index e7e8aeefc..afff7f4cf 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -20,7 +20,7 @@ Basic types ----------- -All basic can be used as vectors, that is one instance representing +All basic types can be used as vectors, that is one instance representing several values, with all operations being executed element-wise. For example, the following computes ten multiplications of integers input by party 0 and 1:: @@ -588,7 +588,7 @@ def input_tensor_from_client(cls, client_id, shape): @classmethod def input_tensor_via(cls, player, content=None, shape=None, binary=True, - one_hot=False): + one_hot=False, skip_input=False, n_bytes=None): """ Input tensor-like data via a player. This overwrites the input file for the relevant player. The following returns an @@ -630,7 +630,10 @@ def input_tensor_via(cls, player, content=None, shape=None, binary=True, else: t = numpy.single else: - t = numpy.int64 + if n_bytes == 1: + t = numpy.int8 + else: + t = numpy.int64 if one_hot: content = numpy.eye(content.max() + 1)[content] content = content.astype(t) @@ -666,9 +669,10 @@ def traverse(content, level): if requested_shape is not None and \ list(shape) != list(requested_shape): raise CompilerError('content contradicts shape') - res = cls.Tensor(shape) - res.input_from(player, binary=binary) - return res + if not skip_input: + res = cls.Tensor(shape) + res.input_from(player, binary=binary, n_bytes=n_bytes) + return res class _vec(Tape._no_truth): def link(self, other): @@ -681,6 +685,13 @@ class _register(Tape.Register, _number, _structure): def n_elements(): return 1 + @classmethod + def new_vector(cls, size): + return cls(size=size) + + def vector_size(self): + return self.size + @vectorized_classmethod def conv(cls, val): if 
isinstance(val, MemValue): @@ -707,7 +718,7 @@ def hard_conv(cls, val): except AttributeError: try: return type(val)(cls.hard_conv(v) for v in val) - except TypeError: + except (TypeError, CompilerError): pass return cls(val) @@ -756,11 +767,11 @@ def bit_compose(cls, bits): return sum(cls.conv(b) << i for i,b in enumerate(bits)) @classmethod - def malloc(cls, size, creator_tape=None): + def malloc(cls, size, creator_tape=None, **kwargs): """ Allocate memory (statically). :param size: compile-time (int) """ - return program.malloc(size, cls, creator_tape=creator_tape) + return program.malloc(size, cls, creator_tape=creator_tape, **kwargs) @classmethod def free(cls, addr): @@ -783,7 +794,11 @@ def __init__(self, reg_type, val, size): else: self[i].load_other(x) elif val is not None: - self.load_other(val) + try: + self.load_other(val) + except: + raise CompilerError( + "cannot convert '%s' to '%s'" % (type(val), type(self))) def _new_by_number(self, i, size=1): res = type(self)(size=size) @@ -943,16 +958,15 @@ def __rsub__(self, other): return self.clear_op(other, subc, subcfi, True) __rsub__.__doc__ = __sub__.__doc__ - def __truediv__(self, other): + def field_div(self, other): """ Field division of public values. Not available for computation modulo a power of two. :param other: convertible type (at least same as :py:obj:`self` and regint/int) """ - return self.clear_op(other, divc, divci) - - def __rtruediv__(self, other): - return self.coerce_op(other, divc, True) - __rtruediv__.__doc__ = __truediv__.__doc__ + try: + return other._rfield_div(self) + except AttributeError: + return self.clear_op(other, divc, divci) def __and__(self, other): """ Bit-wise AND of public values. @@ -1107,6 +1121,20 @@ def __floordiv__(self, other): def __rfloordiv__(self, other): return self.coerce_op(other, floordivc, True) + def __truediv__(self, other): + """ Clear fixed-point division. 
+ + :param other: any compatible type """ + if isinstance(other, cint): + return other.__rtruediv__(self) + try: + return cfix._new(self) / cfix._new(cint(other)) + except: + return NotImplemented + + def __rtruediv__(self, other): + return cfix._new(other) / cfix._new(self) + @vectorize def less_than(self, other, bit_length): """ Clear comparison for particular bit length. @@ -1348,6 +1376,11 @@ def __neg__(self): """ Identity. """ return self + __truediv__ = _clear.field_div + + def __rtruediv__(self, other): + return self.coerce_op(other, divc, True) + @vectorize def __invert__(self): """ Clear bit-wise inversion. """ @@ -1594,8 +1627,14 @@ def __rfloordiv__(self, other): return self.int_op(other, divint, True) __rfloordiv__.__doc__ = __floordiv__.__doc__ - __truediv__ = __floordiv__ - __rtruediv__ = __rfloordiv__ + def __truediv__(self, other): + if isinstance(other, _gf2n): + return NotImplemented + else: + return cint(self) / other + + def __rtruediv__(self, other): + return other / cint(self) def __mod__(self, other): """ Clear modulo computation. @@ -1603,7 +1642,7 @@ def __mod__(self, other): :param other: regint/cint/int """ if util.is_constant(other) and other >= 2 ** 64: return self - return self - (self / other) * other + return self - (self // other) * other def __rmod__(self, other): """ Clear modulo computation. @@ -1661,7 +1700,7 @@ def __lshift__(self, other): def __rshift__(self, other): if isinstance(other, int): - return self / 2**other + return self // 2**other else: return self.cint_op(other, operator.rshift) @@ -1793,6 +1832,9 @@ def output(self): __eq__ = lambda self, other: localint(self._v == other) __ne__ = lambda self, other: localint(self._v != other) + __add__ = lambda self, other: localint(self._v + other) + __radd__ = lambda self, other: localint(self._v + other) + class personal(Tape._no_truth): """ Value known to one player. Supports operations with public values and personal values known to the same player. 
Can be used @@ -1812,7 +1854,7 @@ def __init__(self, player, value): self._v = value @classmethod - def read_int(cls, player): + def read_int(cls, player, n_bytes=None): """ Read integer from ``Player-Data/Input-Binary-P-`` only on party :py:obj:`player`. @@ -1822,7 +1864,7 @@ def read_int(cls, player): """ tmp = cint() - fixinput(player, tmp, 0, 0) + fixinput(player, tmp, n_bytes or 0, 0) return cls(player, tmp) @classmethod @@ -2229,7 +2271,7 @@ def __rsub__(self, other): return self.secret_op(other, subs, submr, subsfi, True) __rsub__.__doc__ = __sub__.__doc__ - def __truediv__(self, other): + def field_div(self, other): """ Secret field division. :param other: any compatible type """ @@ -2237,13 +2279,12 @@ def __truediv__(self, other): one = self.clear_type(1, size=other.size) except AttributeError: one = self.clear_type(1) - return self * (one / other) + return self * one.field_div(other) @vectorize - def __rtruediv__(self, other): + def _rfield_div(self, other): a,b = self.get_random_inverse() - return other * a / (a * self).reveal() - __rtruediv__.__doc__ = __truediv__.__doc__ + return other * a.field_div((a * self).reveal()) @set_instruction_type @vectorize @@ -2311,8 +2352,8 @@ class sint(_secret, _int): The following operations work as expected in the computation domain (modulo a prime or a power of two): ``+, -, *``. ``/`` - denotes the field division modulo a prime. It will reveal if the - divisor is zero. Comparisons operators (``==, !=, <, <=, >, >=``) + denotes a fixed-point division. + Comparisons operators (``==, !=, <, <=, >, >=``) assume that the element in the computation domain represents a signed integer in a restricted range, see below. The same holds for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and @@ -2398,14 +2439,14 @@ def get_random(cls): return res @vectorized_classmethod - def get_input_from(cls, player, binary=False): + def get_input_from(cls, player, binary=False, n_bytes=None): """ Secret input. 
:param player: public (regint/cint/int) :param size: vector size (int, default 1) """ if binary: - return cls(personal.read_int(player)) + return cls(personal.read_int(player, n_bytes=n_bytes)) else: res = cls() inputmixed('int', res, player) @@ -2540,6 +2581,12 @@ def write_to_socket(cls, client_id, values, """ writesockets(client_id, message_type, values[0].size, *values) + @vectorize + def write_fully_to_socket(self, client_id, + message_type=ClientMessageType.NoType): + """ Send full secret to socket """ + writesockets(client_id, message_type, self.size, self) + @vectorize def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType): """ Send only share to socket """ @@ -2557,7 +2604,9 @@ def write_shares_to_socket(cls, client_id, values, @classmethod def read_from_file(cls, start, n_items): - """ Read shares from ``Persistence/Transactions-P.data``. + """ Read shares from + ``Persistence/Transactions-P.data``. See :ref:`this + section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) :param n_items: number of items (int) @@ -2572,7 +2621,8 @@ def read_from_file(cls, start, n_items): @staticmethod def write_to_file(shares, position=None): """ Write shares to ``Persistence/Transactions-P.data`` - (appending at the end). + (appending at the end). See :ref:`this section ` + for details on the data format. 
:param shares: (list or iterable of sint) :param position: start position (int/regint/cint), @@ -2641,7 +2691,7 @@ def __lt__(self, other, bit_length=None, security=None): res = sintbit() comparison.LTZ(res, self - other, (bit_length or program.bit_length) + 1, - security or program.security) + security) return res @read_mem_value @@ -2651,7 +2701,7 @@ def __gt__(self, other, bit_length=None, security=None): res = sintbit() comparison.LTZ(res, other - self, (bit_length or program.bit_length) + 1, - security or program.security) + security) return res @read_mem_value @@ -2670,7 +2720,7 @@ def __ge__(self, other, bit_length=None, security=None): def __eq__(self, other, bit_length=None, security=None): return sintbit.conv( floatingpoint.EQZ(self - other, bit_length or program.bit_length, - security or program.security)) + security)) @read_mem_value @type_comp @@ -2709,7 +2759,6 @@ def mod2m(self, m, bit_length=None, security=None, signed=True): :param bit_length: bit length of input (default: global bit length) """ bit_length = bit_length or program.bit_length - security = security or program.security if isinstance(m, int): if m == 0: return 0 @@ -2737,7 +2786,7 @@ def pow2(self, bit_length=None, security=None): :param bit_length: bit length of input (default: global bit length) """ return floatingpoint.Pow2(self, bit_length or program.bit_length, \ - security or program.security) + security) def __lshift__(self, other, bit_length=None, security=None): """ Secret left shift. @@ -2756,7 +2805,6 @@ def __rshift__(self, other, bit_length=None, security=None, signed=True): :param bit_length: bit length of input (default: global bit length) """ bit_length = bit_length or program.bit_length - security = security or program.security if isinstance(other, int): if other == 0: return self @@ -2783,7 +2831,7 @@ def __rrshift__(self, other): """ Secret right shift. 
:param other: secret or public integer (sint/cint/regint/int) of globale bit length if secret """ - return floatingpoint.Trunc(other, program.bit_length, self, program.security) + return floatingpoint.Trunc(other, program.bit_length, self) @vectorize def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False): @@ -2791,7 +2839,7 @@ def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False): if bit_length == 0: return [] bit_length = bit_length or program.bit_length - assert program.security == security or program.security + program.non_linear.check_security(security) return program.non_linear.bit_dec(self, bit_length, bit_length, maybe_mixed) @@ -2815,7 +2863,6 @@ def round(self, k, m, kappa=None, nearest=False, signed=False): :param kappa: statistical security parameter (int) :param nearest: bool :param signed: bool """ - kappa = kappa or program.security secret = isinstance(m, sint) if nearest: if secret: @@ -2830,6 +2877,20 @@ def round(self, k, m, kappa=None, nearest=False, signed=False): def Norm(self, k, f, kappa=None, simplex_flag=False): return library.Norm(self, k, f, kappa, simplex_flag) + def __truediv__(self, other): + """ Secret fixed-point division. + + :param other: any compatible type """ + if isinstance(other, sint): + return other.__rtruediv__(self) + try: + return sfix._new(self) / cfix._new(cint(other), f=sfix.f, k=sfix.k) + except: + return NotImplemented + + def __rtruediv__(self, other): + return sfix._new(other) / sfix._new(self) + @vectorize def int_div(self, other, bit_length=None, security=None): """ Secret integer division. 
Note that the domain bit length @@ -2839,7 +2900,7 @@ def int_div(self, other, bit_length=None, security=None): :param bit_length: bit length of input (default: global bit length) """ k = bit_length or program.bit_length - kappa = security or program.security + kappa = security tmp = library.IntDiv(self, other, k, kappa) res = type(self)() comparison.Trunc(res, tmp, 2 * k, k, kappa, True) @@ -2963,6 +3024,7 @@ def get_secure_shuffle(n): gensecshuffle(res, n) return res + @read_mem_value def secure_permute(self, shuffle, unit_size=1, reverse=False): res = sint(size=self.size) applyshuffle(res, self, unit_size, shuffle, reverse) @@ -3005,6 +3067,21 @@ def _expand_to_vector(self, size): def copy_from_part(self, source, base, size): picks(self, source, base, 1) + def get_reverse_vector(self): + res = type(self)(size=self.size) + picks(res, self, self.size - 1, -1) + return res + + def get_vector(self, base=0, size=None): + if size is None: + size = len(self) - base + if base == 0 and size == len(self): + return self + assert base + size <= len(self) + res = type(self)(size=size) + picks(res, self, base, 1) + return res + @classmethod def concat(cls, parts): parts = list(parts) @@ -3013,6 +3090,10 @@ def concat(cls, parts): concats(res, *args) return res + @vectorize + def output(self): + print_reg_plains(self) + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -3137,6 +3218,7 @@ def store_in_mem(self, address): """ Store in memory by public address. """ self._store_in_mem(address, gstms, gstmsi) + @vectorize_init def __init__(self, val=None, size=None): super(sgf2n, self).__init__('sg', val=val, size=size) @@ -3144,6 +3226,9 @@ def __neg__(self): """ Identity. """ return self + __truediv__ = _secret.field_div + __rtruediv__ = _secret._rfield_div + @vectorize def __invert__(self): """ Secret bit-wise inversion. 
""" @@ -3637,6 +3722,7 @@ def load_int(self, other): raise CompilerError('Invalid signed %d-bit integer: %d' % \ (self.n_bits, other)) + @vectorize def load_other(self, other): if isinstance(other, sgf2nint): gmovs(self, self.compose(other.bit_decompose(self.n_bits))) @@ -4200,6 +4286,15 @@ def write_shares_to_socket(cls, client_id, values, cls.int_type.write_shares_to_socket( client_id, [x.v for x in values], message_type) + @vectorized_classmethod + def read_from_socket(cls, client_id, n=1): + return util.untuplify([cls._new(x) for x in util.tuplify( + cls.int_type.read_from_socket(client_id, n))]) + + @classmethod + def write_to_socket(cls, client_id, values): + cls.int_type.write_to_socket(client_id, [x.v for x in values]) + @vectorized_classmethod def load_mem(cls, address, mem_type=None): """ Load from memory by public address. """ @@ -4273,7 +4368,8 @@ def matrix_mul(cls, A, B, n, res_params=None): @classmethod def read_from_file(cls, *args, **kwargs): """ Read shares from ``Persistence/Transactions-P.data``. - Precision must be the same as when storing. + Precision must be the same as when storing. See :ref:`this + section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) @@ -4288,7 +4384,8 @@ def read_from_file(cls, *args, **kwargs): @classmethod def write_to_file(cls, shares, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data``. + ``Persistence/Transactions-P.data``. See :ref:`this + section ` for details on the data format. 
:param shares: (list or iterable of sfix) :param position: start position (int/regint/cint), @@ -4588,7 +4685,8 @@ def __truediv__(self, other): nearest=self.round_nearest) elif isinstance(other, cfix): v = library.sint_cint_division(self.v, other.v, self.k, self.f, - self.kappa) + self.kappa, + nearest=self.round_nearest) else: raise TypeError('Incompatible fixed point types in division') return self._new(v, k=self.k, f=self.f) @@ -4656,7 +4754,7 @@ class sfix(_fix): default_type = sint @vectorized_classmethod - def get_input_from(cls, player, binary=False): + def get_input_from(cls, player, binary=False, n_bytes=None): """ Secret fixed-point input. :param player: public (regint/cint/int) @@ -4770,6 +4868,13 @@ def unreduced(self, v, other=None, res_params=None, n_summands=1): def multipliable(v, k, f, size): return cfix._new(cint.conv(v, size=size), k, f) + def dot(self, other): + """ Dot product with :py:class:`sint:`. """ + if isinstance(other, sint): + return self._new(sint.dot_product(self.v, other), k=self.k, f=self.f) + else: + raise NotImplementedError() + def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -4793,6 +4898,12 @@ def prefix_sum(self): def sum(self): return self._new(self.v.sum()) + def get_reverse_vector(self): + return self._new(self.v.get_reverse_vector(), k=self.k, f=self.f) + + def get_vector(self, *args, **kwargs): + return self._new(self.v.get_vector(*args, **kwargs), k=self.k, f=self.f) + @classmethod def concat(cls, parts): parts = list(parts) @@ -5486,6 +5597,11 @@ def update(self, other): self.z.update(other.z) self.s.update(other.s) + def for_mux(self, other): + other = self.coerce(other) + f = lambda x: type(self)(*x) + return f, sint(list(self)), sint(list(other)) + class cfloat(Tape._no_truth): """ Helper class for printing revealed sfloats. 
""" __slots__ = ['v', 'p', 'z', 's', 'nan'] @@ -5763,6 +5879,9 @@ def _store(self, value, address): tmp.store_in_mem(address) def __len__(self): + if self.length is None: + raise CompilerError('this functionality is not available ' + 'for variable-length arrays') return self.length def total_size(self): @@ -5887,6 +6006,19 @@ def assign_slice_vector(self, slice, vector): addresses = self.get_slice_addresses(slice) vector.store_in_mem(addresses) + def permute(self, permutation, reverse=False, n_threads=None): + """ Public permutation. + + :param permutation: cleartext :py:class`Array` containing number + in :math:`[0,n-1]` where :math:`n` is the length of this array + :param reverse: whether to apply the inverse of the permutation + + """ + if reverse: + self.assign_slice_vector(permutation, self.get_vector()) + else: + self.assign_vector(self.get_slice_vector(permutation)) + def expand_to_vector(self, index, size): """ Create vector from single entry. @@ -5930,7 +6062,8 @@ def _(i): def read_from_file(self, start): """ Read content from ``Persistence/Transactions-P.data``. - Precision must be the same as when storing if applicable. + Precision must be the same as when storing if applicable. See + :ref:`this section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) @@ -5943,13 +6076,36 @@ def read_from_file(self, start): def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data``. + ``Persistence/Transactions-P.data``. See :ref:`this + section ` for details on the data format. :param position: start position (int/regint/cint), defaults to end of file """ self.value_type.write_to_file(list(self), position) + def read_from_socket(self, socket, debug=False): + """ Read content from socket. """ + if debug: + library.print_str('reading %s...' 
% self) + # hard-coded budget for interopability + @library.multithread(None, len(self), max_size=10 ** 6) + def _(base, size): + self.assign_vector( + self.value_type.read_from_socket(socket, size=size), base=base) + if debug: + library.print_ln('done') + + def write_to_socket(self, socket, debug=False): + """ Write content to socket. """ + if debug: + library.print_ln('writing %s' % self) + # hard-coded budget for interopability + @library.multithread(None, len(self), max_size=10 ** 6) + def _(base, size): + self.value_type.write_to_socket( + socket, [self.get_vector(base=base, size=size)]) + def __add__(self, other): """ Vector addition. @@ -6254,9 +6410,13 @@ def assign(self, other): """ Assign container to content. Not implemented for floating-point. :param other: container of matching size and type """ - if self.value_type.n_elements() > 1: - assert self.sizes == other.sizes - self.assign_vector(other.get_vector()) + try: + if self.value_type.n_elements() > 1: + assert self.sizes == other.sizes + self.assign_vector(other.get_vector()) + except: + for i, x in enumerate(other): + self[i].assign(x) def get_part_vector(self, base=0, size=None): """ Vector from range of the first dimension, including all @@ -6297,15 +6457,41 @@ def assign_slice_vector(self, slice, vector): addresses = self.get_slice_addresses(slice) vector.store_in_mem(self.address + addresses) - def get_slice_addresses(self, slice): + def get_part_size(self): assert self.value_type.n_elements() == 1 - part_size = reduce(operator.mul, self.sizes[1:]) + return reduce(operator.mul, self.sizes[1:]) + + def get_slice_addresses(self, slice, part_size=None): + part_size = part_size or self.get_part_size() assert len(slice) * part_size <= self.total_size() base = regint.inc(len(slice) * part_size, slice.address, 1, part_size) inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size) addresses = slice.value_type.load_mem(base) * part_size + inc return addresses + def permute(self, permutation, 
reverse=False, n_threads=None): + """ Public permutation along first dimension. + + :param permutation: cleartext :py:class`Array` containing number + in :math:`[0,n-1]` where :math:`n` is the length of this array + :param reverse: whether to apply the inverse of the permutation + + """ + @library.multithread(n_threads, self.get_part_size()) + def _(base, size): + addresses = self.get_slice_addresses(permutation, part_size=1) + addresses *= self.get_part_size() + @library.for_range_opt(size) + def _(j): + i = base + j + if reverse: + v = self.get_column(i) + v.store_in_mem(self.address + i + addresses) + else: + v = self.value_type.load_mem( + self.address + i + addresses) + self.set_column(i, v) + def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 assert len(indices) == len(self.sizes) @@ -6389,7 +6575,8 @@ def _(i): def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data``. + ``Persistence/Transactions-P.data``. See :ref:`this + section ` for details on the data format. :param position: start position (int/regint/cint), defaults to end of file @@ -6404,7 +6591,8 @@ def _(i): def read_from_file(self, start): """ Read content from ``Persistence/Transactions-P.data``. - Precision must be the same as when storing if applicable. + Precision must be the same as when storing if applicable. See + :ref:`this section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) @@ -6417,6 +6605,14 @@ def _(i): start.write(self[i].read_from_file(start)) return start + def write_to_socket(self, socket, debug=False): + """ Write content to socket. """ + self.array.write_to_socket(socket, debug=debug) + + def read_from_socket(self, socket, debug=False): + """ Read content from socket. """ + self.array.read_from_socket(socket, debug=debug) + def schur(self, other): """ Element-wise product. 
@@ -6758,7 +6954,28 @@ def parallel_mul(self, other): res = self.value_type.dot_product(a, b) return res - def transpose(self): + def get_column(self, index): + """ Get matrix column as vector. + + :param index: regint/cint/int + """ + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.get_part_size()) + return self.value_type.load_mem(addresses) + + def set_column(self, index, vector): + """ Change column. + + :param index: regint/cint/int + :param vector: short enought vector of compatible type + """ + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.get_part_size()) + self.value_type.conv(vector).store_in_mem(addresses) + + def transpose(self, n_threads=None): """ Matrix transpose. :param self: two-dimensional """ @@ -6766,13 +6983,24 @@ def transpose(self): res = Matrix(self.sizes[1], self.sizes[0], self.value_type) library.break_point() if self.value_type.n_elements() == 1: - nr = self.sizes[1] - nc = self.sizes[0] - a = regint.inc(nr * nc, 0, nr, 1, nc) - b = regint.inc(nr * nc, 0, 1, nc) - res[:] = self.value_type.load_mem(self.address + a + b) + if self.sizes[0] < program.budget: + if self.sizes[1] < program.budget: + nr = self.sizes[1] + nc = self.sizes[0] + a = regint.inc(nr * nc, 0, nr, 1, nc) + b = regint.inc(nr * nc, 0, 1, nc) + res[:] = self.value_type.load_mem(self.address + a + b) + else: + @library.for_range_multithread(n_threads, 1, self.sizes[0]) + def _(i): + res.set_column(i, self[i][:]) + else: + @library.for_range_multithread(n_threads, 1, self.sizes[1]) + def _(i): + res[i][:] = self.get_column(i) else: - @library.for_range_opt(self.sizes[1], budget=100) + @library.for_range_opt_multithread(n_threads, self.sizes[1], + budget=100) def _(i): @library.for_range_opt(self.sizes[0], budget=100) def _(j): @@ -6801,7 +7029,7 @@ def secure_shuffle(self): """ self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) - 
def secure_permute(self, permutation, reverse=False): + def secure_permute(self, permutation, reverse=False, n_threads=None): """ Securely permute rows (first index). See :py:func:`secure_shuffle` for references. @@ -6809,8 +7037,12 @@ def secure_permute(self, permutation, reverse=False): :param reverse: whether to apply inverse (default: False) """ - self.assign_vector(self.get_vector().secure_permute( - permutation, self.part_size(), reverse)) + if n_threads is not None: + permutation = MemValue(permutation) + @library.for_range_multithread(n_threads, 1, self.get_part_size()) + def _(i): + self.set_column(i, self.get_column(i).secure_permute( + permutation, reverse=reverse)) def sort(self, key_indices=None, n_bits=None): """ Sort sub-arrays (different first index) in place. @@ -6829,6 +7061,9 @@ def sort(self, key_indices=None, n_bits=None): return if key_indices is None: key_indices = (0,) * (len(self.sizes) - 1) + if len(key_indices) != len(self.sizes) - 1: + raise CompilerError('length of key_indices has to be one less ' + 'than the dimension') key_indices = (None,) + util.tuplify(key_indices) from . import sorting keys = self.get_vector_by_indices(*key_indices) @@ -6971,7 +7206,8 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): @staticmethod def create_from(rows): - rows = list(rows) + if not isinstance(rows, _vectorizable): + rows = list(rows) if isinstance(rows[0], (list, tuple, Array)): t = type(rows[0][0]) else: @@ -6983,20 +7219,15 @@ def create_from(rows): raise CompilerError( 'accidental shortening by creating matrix') res = Matrix(len(rows), len(rows[0]), t) - for i in range(len(rows)): - res[i].assign(rows[i]) + if isinstance(rows, _vectorizable): + @library.for_range_opt(len(rows)) + def _(i): + res[i].assign(rows[i]) + else: + for i in range(len(rows)): + res[i].assign(rows[i]) return res - def get_column(self, index): - """ Get column as vector. 
- - :param index: regint/cint/int - """ - assert self.value_type.n_elements() == 1 - addresses = regint.inc(self.sizes[0], self.address + index, - self.sizes[1]) - return self.value_type.load_mem(addresses) - def get_columns(self): return (self.get_column(i) for i in range(self.sizes[1])) @@ -7006,17 +7237,6 @@ def get_column_by_row_indices(self, rows, column): regint.inc(len(rows), self.address + column, 0) return self.value_type.load_mem(addresses) - def set_column(self, index, vector): - """ Change column. - - :param index: regint/cint/int - :param vector: short enought vector of compatible type - """ - assert self.value_type.n_elements() == 1 - addresses = regint.inc(self.sizes[0], self.address + index, - self.sizes[1]) - self.value_type.conv(vector).store_in_mem(addresses) - def concat_columns(self, other): """ Concatenate two matrices by columns. """ assert self.sizes[0] == other.sizes[0] @@ -7217,6 +7437,8 @@ def reveal(self): bit_and = lambda self,other: self.read().bit_and(other) bit_not = lambda self: self.read().bit_not() + print_if = lambda self,*args,**kwargs: self.read().print_if(*args, **kwargs) + def expand_to_vector(self, size=None): if program.curr_block == self.last_write_block: return self.read().expand_to_vector(size) diff --git a/Dockerfile b/Dockerfile index 760d19349..e6b361f99 100644 --- a/Dockerfile +++ b/Dockerfile @@ -64,14 +64,14 @@ RUN pip install --upgrade pip ipython COPY . . 
-ARG arch=native +ARG arch= ARG cxx=clang++-11 ARG use_ntl=0 ARG prep_dir="Player-Data" ARG ssl_dir="Player-Data" -RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \ - && echo "CXX = ${cxx}" >> CONFIG.mine \ +RUN if test -n "${arch}"; then echo "ARCH = -march=${arch}" >> CONFIG.mine; fi +RUN echo "CXX = ${cxx}" >> CONFIG.mine \ && echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \ && echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \ && echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \ diff --git a/ECDSA/CurveElement.h b/ECDSA/CurveElement.h index 254271e1b..dc0c8f08b 100644 --- a/ECDSA/CurveElement.h +++ b/ECDSA/CurveElement.h @@ -50,7 +50,6 @@ class CurveElement : public ValueInterface void assign_zero() { *this = 0; } bool is_zero() { return *this == 0; } - void add(octetStream& os) { *this += os.get(); } void pack(octetStream& os) const; void unpack(octetStream& os); diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 1ff3273f8..1059506b3 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -166,13 +166,3 @@ bool P256Element::operator !=(const P256Element& other) const { return not (*this == other); } - -octetStream P256Element::hash(size_t n_bytes) const -{ - octetStream os; - pack(os); - auto res = os.hash(); - assert(n_bytes >= res.get_length()); - res.resize_precise(n_bytes); - return res; -} diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 534c2a997..9a6063609 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -56,15 +56,9 @@ class P256Element : public ValueInterface bool operator==(const P256Element& other) const; bool operator!=(const P256Element& other) const; - void assign_zero() { *this = {}; } - bool is_zero() { return *this == P256Element(); } - void add(octetStream& os, int = -1) { *this += os.get(); } - void pack(octetStream& os, int = -1) const; void unpack(octetStream& os, int = -1); - octetStream hash(size_t n_bytes) const; - friend ostream& operator<<(ostream& s, const 
P256Element& x); }; diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index b040dd5e8..daef496b7 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -131,6 +131,12 @@ int main(int argc, char** argv) case 'R': { int R = specification.get(); + int R2 = specification.get(); + if (R2 != 64) + { + cerr << R2 << "-bit ring not implemented" << endl; + } + switch (R) { case 64: diff --git a/ExternalIO/bankers-bonus-client.py b/ExternalIO/bankers-bonus-client.py index bd9665aff..71aa13f41 100755 --- a/ExternalIO/bankers-bonus-client.py +++ b/ExternalIO/bankers-bonus-client.py @@ -14,24 +14,17 @@ client = Client(['localhost'] * n_parties, 14000, client_id) -type = client.specification.get_int(4) - -if type == ord('R'): - domain = Z2(client.specification.get_int(4)) -elif type == ord('p'): - domain = Fp(client.specification.get_bigint()) -else: - raise Exception('invalid type') - for socket in client.sockets: os = octetStream() os.store(finish) os.Send(socket) +def run(x): + client.send_private_inputs([x]) + + print('Winning client id is :', client.receive_outputs(1)[0]) + # running two rounds # first for sint, then for sfix -for x in bonus, bonus * 2 ** 16: - client.send_private_inputs([domain(x)]) - - print('Winning client id is :', - int(client.receive_outputs(domain, 1)[0])) +run(bonus) +run(bonus * 2 ** 16) diff --git a/ExternalIO/client.py b/ExternalIO/client.py index f68813cbc..6d560f465 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -2,6 +2,7 @@ import socket, ssl import struct import time +from domains import * # The following function is either taken directly or derived from: # https://stackoverflow.com/questions/12248132/how-to-change-tcp-keepalive-timer-using-python-script @@ -61,6 +62,15 @@ def __init__(self, hostnames, port_base, my_client_id): self.specification = octetStream() self.specification.Receive(self.sockets[0]) + type = self.specification.get_int(4) + if 
type == ord('R'): + self.domain = Z2(self.specification.get_int(4)) + self.clear_domain = Z2(self.specification.get_int(4)) + elif type == ord('p'): + self.domain = Fp(self.specification.get_bigint()) + self.clear_domain = self.domain + else: + raise Exception('invalid type') def receive_triples(self, T, n): triples = [[0, 0, 0] for i in range(n)] @@ -89,18 +99,19 @@ def receive_triples(self, T, n): return triples def send_private_inputs(self, values): - T = type(values[0]) + T = self.domain triples = self.receive_triples(T, len(values)) os = octetStream() assert len(values) == len(triples) for value, triple in zip(values, triples): - (value + triple[0]).pack(os) + (T(value) + triple[0]).pack(os) for socket in self.sockets: os.Send(socket) - def receive_outputs(self, T, n): + def receive_outputs(self, n): + T = self.domain triples = self.receive_triples(T, n) - return [triple[0] for triple in triples] + return [int(self.clear_domain(triple[0].v)) for triple in triples] class octetStream: def __init__(self, value=None): diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index bb22c11d2..4da39740d 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -67,20 +67,15 @@ void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1, cc0.Scale(pk.p()); cc1.Scale(pk.p()); // Now do the multiply - Rq_Element d0,d1,d2; - - mul(d0,cc0.cc0,cc1.cc0); - mul(d1,cc0.cc0,cc1.cc1); - mul(d2,cc0.cc1,cc1.cc0); - add(d1,d1,d2); - mul(d2,cc0.cc1,cc1.cc1); + auto d0 = cc0.cc0 * cc1.cc0; + auto d1 = cc0.cc0 * cc1.cc1 + cc0.cc1 * cc1.cc0; + auto d2 = cc0.cc1 * cc1.cc1; d2.negate(); // Now do the switch key d2.raise_level(); - Rq_Element t; d0.mul_by_p1(); - mul(t,pk.bs(),d2); + auto t = pk.bs()* d2; add(d0,d0,t); d1.mul_by_p1(); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index eb47a4208..cc68efc34 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -29,7 +29,6 @@ class Ciphertext word pk_id; public: - static int size() { return 0; } const FHE_Params& get_params() 
const { return *params; } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 08870dc75..3329d7363 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -94,8 +94,8 @@ void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, add(PK.Sw_b,PK.Sw_b,es); // bs=bs-p1*s^2 - Rq_Element s2; - mul(s2,sk,sk); // Mult at level 0 + // Mult at level 0 + auto s2 = sk * sk; s2.mul_by_p1(); // This raises back to level 1 sub(PK.Sw_b,PK.Sw_b,s2); } @@ -155,17 +155,12 @@ void FHE_PK::quasi_encrypt(Ciphertext& c, if (&rc.get_params()!=params) { throw params_mismatch(); } assert(pr != 0); - Rq_Element ed,edd,c0,c1,aa; - // c1=a0*u+p*v - mul(aa,a0,rc.u()); - mul(ed,rc.v(),pr); - add(c1,aa,ed); + auto c1 = a0 * rc.u() + rc.v() * pr; // c0 = b0 * u + p * w + mess - mul(c0,b0,rc.u()); - mul(edd,rc.w(),pr); - add(edd,edd,mess); + auto c0 = b0 * rc.u(); + auto edd = rc.w() * pr + mess; if (params->n_mults() == 0) edd.change_rep(evaluation); else @@ -218,10 +213,7 @@ Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const { if (&c.get_params()!=params) { throw params_mismatch(); } - Rq_Element ans; - - mul(ans,c.c1(),sk); - sub(ans,c.c0(),ans); + auto ans = c.c0() - c.c1() * sk; ans.change_rep(polynomial); return ans; } @@ -267,8 +259,7 @@ void FHE_SK::dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_ Ciphertext cc=ctx; cc.Scale(pr); // First do the basic decryption - Rq_Element dec_sh; - mul(dec_sh,cc.c1(),sk); + auto dec_sh = cc.c1() * sk; if (player_number==0) { sub(dec_sh,cc.c0(),dec_sh); } else diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index b0e88e970..1c4fc8f3d 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -89,9 +89,6 @@ class FHE_SK bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; } - void add(octetStream& os, int = -1) - { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; } - void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const; template @@ -120,10 +117,12 @@ class FHE_PK bigint p() 
const { return pr; } void assign(const Rq_Element& a,const Rq_Element& b, - const Rq_Element& sa = {},const Rq_Element& sb = {} + const Rq_Element& sa,const Rq_Element& sb ) { a0=a; b0=b; Sw_a=sa; Sw_b=sb; } + void assign(const Rq_Element& a,const Rq_Element& b) + { a0=a; b0=b; } FHE_PK(const FHE_Params& pms); diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index 04698ade8..52fa9e1d3 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -27,6 +27,8 @@ class RingReadIterator; class Ring_Element { + friend class Rq_Element; + RepType rep; /* FFTD is defined as a pointer so each different Ring_Element @@ -41,6 +43,9 @@ class Ring_Element vector element; + /* Careful calling this one, as FFTD will not be defined */ + Ring_Element(RepType r=polynomial) : FFTD(0) { rep=r; } + public: // Used to basically make sure *this is able to cope @@ -63,9 +68,6 @@ class Ring_Element void assign_zero(); void assign_one(); - /* Careful calling this one, as FFTD will not be defined */ - Ring_Element(RepType r=polynomial) : FFTD(0) { rep=r; } - Ring_Element(const FFT_Data& prd,RepType r=polynomial); template diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index 97977fdba..e3292a0c4 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -23,7 +23,7 @@ Rq_Element::Rq_Element(const vector& prd, RepType r0, RepType r1) void Rq_Element::set_data(const vector& prd) { - a.resize(prd.size()); + a.resize(prd.size(), {}); for(size_t i = 0; i < a.size(); i++) a[i].set_data(prd[i]); lev=n_mults(); @@ -50,7 +50,7 @@ void Rq_Element::assign_one() void Rq_Element::partial_assign(const Rq_Element& other) { lev=other.lev; - a.resize(other.a.size()); + a.resize(other.a.size(), {}); } void Rq_Element::negate() @@ -112,13 +112,6 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b) } } -void Rq_Element::add(octetStream& os, int) -{ - Rq_Element tmp(*this); - tmp.unpack(os); - *this += tmp; -} - void Rq_Element::randomize(PRNG& G,int l) { set_level(l); diff --git 
a/FHE/Rq_Element.h b/FHE/Rq_Element.h index cdf3626c7..6db65a2dd 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -33,6 +33,10 @@ class Rq_Element vector a; int lev; + // Must be careful not to call by mistake + Rq_Element(RepType r0=evaluation,RepType r1=polynomial) : + a({r0, r1}), lev(n_mults()) {} + public: int n_mults() const { return a.size() - 1; } @@ -46,10 +50,6 @@ class Rq_Element void assign_one(); void partial_assign(const Rq_Element& e); - // Must be careful not to call by mistake - Rq_Element(RepType r0=evaluation,RepType r1=polynomial) : - a({r0, r1}), lev(n_mults()) {} - // Pass in a pair of FFT_Data as a vector Rq_Element(const vector& prd, RepType r0 = evaluation, RepType r1 = polynomial); @@ -97,8 +97,6 @@ class Rq_Element friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); - void add(octetStream& os, int = -1); - template Rq_Element& operator+=(const vector& other); diff --git a/FHEOffline/DistKeyGen.cpp b/FHEOffline/DistKeyGen.cpp index 482a87e59..c51c13d33 100644 --- a/FHEOffline/DistKeyGen.cpp +++ b/FHEOffline/DistKeyGen.cpp @@ -15,13 +15,10 @@ void Encrypt_Rq_Element(Ciphertext& c,const Rq_Element& mess, const Random_Coins& rc, const FHE_PK& pk) { - Rq_Element ed, edd, c0, c1; - mul(c1, pk.a(), rc.u()); - mul(ed, rc.v(), pk.p()); - add(c1, c1, ed); + auto c1 = pk.a() * rc.u() + rc.v() * pk.p(); + auto c0 = pk.b() * rc.u(); + auto edd = rc.w() * pk.p(); - mul(c0, pk.b(), rc.u()); - mul(edd, rc.w(), pk.p()); edd.change_rep(evaluation,evaluation); add(edd,edd,mess); add(c0,c0,edd); diff --git a/GC/Processor.h b/GC/Processor.h index e21cf6007..791552a09 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -87,6 +87,7 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching void xorc(const ::BaseInstruction& instruction); void nots(const ::BaseInstruction& instruction); void notcb(const ::BaseInstruction& instruction); + void 
movsb(const ::BaseInstruction& instruction); void andm(const ::BaseInstruction& instruction); void and_(const vector& args, bool repeat); void andrs(const vector& args) { and_(args, true); } diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 22b484d24..084f605b2 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -283,6 +283,13 @@ void Processor::notcb(const ::BaseInstruction& instruction) } } +template +void Processor::movsb(const ::BaseInstruction& instruction) +{ + for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++) + S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i]; +} + template void Processor::andm(const ::BaseInstruction& instruction) { diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 1fe06c486..bb1930452 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -32,7 +32,7 @@ void TinierSharePrep::buffer_secret_triples() assert(triple_generator != 0); params.generateBits = false; vector> triples; - TripleShuffleSacrifice sacrifice; + TripleShuffleSacrifice sacrifice(DATA_GF2); size_t required; required = sacrifice.minimum_n_inputs_with_combining( BaseMachine::batch_size(DATA_TRIPLE)); diff --git a/GC/instructions.h b/GC/instructions.h index 67ea461a1..488c1924d 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -65,7 +65,7 @@ X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \ X(LDMCBI, PROC.mem_op(SIZE, PROC.C, MMC, R0, Ci[REG1])) \ X(STMCBI, PROC.mem_op(SIZE, MMC, PROC.C, Ci[REG1], R0)) \ - X(MOVSB, S0 = PS1) \ + X(MOVSB, PROC.movsb(INST)) \ X(TRANS, T::trans(PROC, IMM, EXTRA)) \ X(BITB, PROC.random_bit(S0)) \ X(REVEAL, T::reveal_inst(PROC, EXTRA)) \ @@ -123,7 +123,7 @@ X(LDMINTI, I0 = MII) \ X(STMINTI, MII = I0) \ X(PUSHINT, PROC.pushi(I0.get())) \ - X(POPINT, long x; PROC.popi(x); I0 = x) \ + X(POPINT, PROC.popi(I0)) \ X(MOVINT, I0 = PI1) \ X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \ X(LDARG, I0 = PROC.get_arg()) \ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 
3d8c9a414..a4f22206e 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -72,6 +72,12 @@ ShamirOptions::ShamirOptions(ez::ezOptionParser& opt, int argc, const char** arg ); opt.parse(argc, argv); opt.get("-N")->getInt(nparties); + if (nparties < 3) + { + cerr << "Protocols based on Shamir secret sharing require at least " + << "three parties." << endl; + exit(1); + } set_threshold(opt); opt.resetArgs(); } diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 469116953..017ce3ddd 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -26,23 +26,9 @@ int main(int argc, const char** argv) ez::ezOptionParser opt; RingOptions ring_opts(opt, argc, argv); online_opts = {opt, argc, argv, FakeShare>()}; - opt.parse(argc, argv); opt.syntax = string(argv[0]) + " "; - - string progname; - if (opt.firstArgs.size() > 1) - progname = *opt.firstArgs.at(1); - else if (not opt.lastArgs.empty()) - progname = *opt.lastArgs.at(0); - else if (not opt.unknownArgs.empty()) - progname = *opt.unknownArgs.at(0); - else - { - string usage; - opt.getUsage(usage); - cerr << usage << endl; - exit(1); - } + online_opts.finalize(opt, argc, argv, false); + string& progname = online_opts.progname; #ifdef ROUND_NEAREST_IN_EMULATION cerr << "Using nearest rounding instead of probabilistic truncation" << endl; diff --git a/Machines/no-party.cpp b/Machines/no-party.cpp index ceb35b089..47d647249 100644 --- a/Machines/no-party.cpp +++ b/Machines/no-party.cpp @@ -7,6 +7,7 @@ #include "Processor/OnlineMachine.hpp" #include "Processor/Machine.hpp" +#include "Processor/OnlineOptions.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/MalRepRingPrep.hpp" #include "Protocols/ReplicatedPrep.hpp" @@ -17,7 +18,7 @@ int main(int argc, const char** argv) { ez::ezOptionParser opt; - OnlineOptions::singleton = {opt, argc, argv}; + OnlineOptions::singleton = {opt, argc, argv, NoShare()}; OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton); 
OnlineOptions::singleton.finalize(opt, argc, argv); machine.start_networking(); diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index ed2d53e2f..b02ff73ad 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -7,6 +7,7 @@ #include "Processor/Machine.h" #include "Processor/RingOptions.h" #include "Protocols/Spdz2kShare.h" +#include "Protocols/SPDZ2k.h" #include "Math/gf2n.h" #include "Networking/Server.h" @@ -62,8 +63,10 @@ int main(int argc, const char** argv) cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line " << (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+" << s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl; - cerr << "Alternatively, compile with -DRING_SIZE=" << k - << " and -DSPDZ2K_DEFAULT_SECURITY=" << s << endl; + cerr << "Alternatively, put 'MY_CFLAGS += -DRING_SIZE=" << k + << " -DSPDZ2K_DEFAULT_SECURITY=" << s + << "' in 'CONFIG.mine' before running 'make spdz2k-party.x'" + << endl; } exit(1); } diff --git a/Makefile b/Makefile index d4a8dedc4..50a634526 100644 --- a/Makefile +++ b/Makefile @@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x ifeq ($(OS), Darwin) setup: mac-setup else -setup: boost linux-machine-setup +setup: maybe-boost linux-machine-setup endif tldr: setup @@ -296,13 +296,17 @@ deps/SimplestOT_C/ref10/Makefile: .PHONY: Programs/Circuits Programs/Circuits: - git submodule update --init Programs/Circuits + git submodule update --init Programs/Circuits || git clone https://github.com/mkskeller/bristol-fashion Programs/Circuits deps/libOTe/libOTe: git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe boost: deps/libOTe/libOTe cd deps/libOTe; \ python3 build.py --setup --boost --install=$(CURDIR)/local +maybe-boost: deps/libOTe/libOTe + cd `mktemp -d`; \ + PATH="$(CURDIR)/local/bin:$(PATH)" cmake $(CURDIR)/deps/libOTe || \ + { cd -; make boost; } OTE_OPTS 
+= -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX) -DCMAKE_INSTALL_LIBDIR=lib @@ -334,11 +338,12 @@ OT/OTExtensionWithMatrix.o: $(OTE) endif local/lib/liblibOTe.a: deps/libOTe/libOTe + make maybe-boost; \ cd deps/libOTe; \ PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS) && \ touch ../../local/lib/liblibOTe.a -$(SHARED_OTE): deps/libOTe/libOTe +$(SHARED_OTE): deps/libOTe/libOTe maybe-boost cd deps/libOTe; \ python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS) diff --git a/Math/Bit.h b/Math/Bit.h index b5d565102..10c4e0187 100644 --- a/Math/Bit.h +++ b/Math/Bit.h @@ -51,11 +51,6 @@ class Bit : public BitVec_ return other * *this; } - void add(octetStream& os, int = -1) - { - *this += os.get(); - } - void pack(octetStream& os, int = -1) const { super::pack(os, 1); diff --git a/Math/BitVec.h b/Math/BitVec.h index a362d0101..ca63c24cb 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -56,13 +56,6 @@ class BitVec_ : public IntBase void extend_bit(BitVec_& res, int) const { res = extend_bit(); } - void add(octetStream& os, int n_bits) - { - BitVec_ tmp; - tmp.unpack(os, n_bits); - *this += tmp; - } - void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; } void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = this->mask(n); } diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 489ec5ae9..a412c7e04 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -138,12 +138,6 @@ class FixedVec v[i] = (x.v[i] * y.v[i]); } - void add(octetStream& os) - { - for (int i = 0; i < L; i++) - v[i].add(os); - } - void negate() { for (auto& x : v) diff --git a/Math/ValueInterface.cpp b/Math/ValueInterface.cpp index 68758fb08..ad7036b46 100644 --- a/Math/ValueInterface.cpp +++ b/Math/ValueInterface.cpp @@ -8,6 +8,8 @@ #include +const false_type ValueInterface::binary; + void ValueInterface::check_setup(const string& directory) { struct stat sb; diff --git a/Math/Z2k.h 
b/Math/Z2k.h index 924aa9536..2c6704d34 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -156,8 +156,6 @@ class Z2 : public ValueInterface bool operator==(const Z2& other) const; bool operator!=(const Z2& other) const { return not (*this == other); } - void add(octetStream& os, int = -1) { *this += (os.consume(size())); } - Z2 lazy_add(const Z2& x) const; Z2 lazy_mul(const Z2& x) const; diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 8677df404..6dc8b3130 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -136,7 +136,7 @@ inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y template<> inline void Zp_Data::Add<1>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const { -#if defined(__clang__) || !defined(__x86_64__) +#if defined(__clang__) || !defined(__x86_64__) || (__GNUC__ == 10) Add<0>(ans, x, y); #else *ans = *x + *y; diff --git a/Math/bigint.cpp b/Math/bigint.cpp index 1c01c53aa..aef081a3a 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -87,12 +87,6 @@ bigint::bigint(const mp_limb_t* data, size_t n_limbs) mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data); } -void bigint::add(octetStream& os, int) -{ - tmp.unpack(os); - *this += tmp; -} - string to_string(const bigint& x) { stringstream ss; diff --git a/Math/bigint.h b/Math/bigint.h index 41da70f1e..795bcbff3 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -102,8 +102,6 @@ class bigint : public mpz_class void mul(const bigint& x, const bigint& y) { *this = x * y; } - void add(octetStream& os, int = -1); - #ifdef REALLOC_POLICE ~bigint() { lottery(); } void lottery(); diff --git a/Math/gf2n.h b/Math/gf2n.h index bed5ba724..a19ca6a0d 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -136,10 +136,6 @@ class gf2n_ : public ValueInterface // x+y void add(const gf2n_& x,const gf2n_& y) { a=x.a^y.a; } - void add(octet* x) - { a^=*(U*)(x); } - void add(octetStream& os, int = -1) - { add(os.consume(size())); } void sub(const gf2n_& x,const gf2n_& y) { a=x.a^y.a; } // = x * y diff 
--git a/Math/gfp.h b/Math/gfp.h index 31f3a571c..43bfb5424 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -189,12 +189,8 @@ class gfp_ : public ValueInterface bool operator!=(const gfp_& y) const { return !equal(y); } // x+y - void add(octetStream& os, int = -1) - { add(os.consume(size())); } void add(const gfp_& x,const gfp_& y) { ZpD.Add(a.x,x.a.x,y.a.x); } - void add(void* x) - { ZpD.Add(a.x,a.x,(mp_limb_t*)x); } void sub(const gfp_& x,const gfp_& y) { ZpD.Sub(a.x,x.a.x,y.a.x); } // = x * y diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 383d45751..39216ef45 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -295,12 +295,6 @@ bool gfpvar_::operator !=(const gfpvar_& other) const return not (*this == other); } -template -void gfpvar_::add(octetStream& other, int) -{ - *this += other.get>(); -} - template void gfpvar_::negate() { diff --git a/Math/gfpvar.h b/Math/gfpvar.h index ceb4e9ed3..e2cafb365 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -149,8 +149,6 @@ class gfpvar_ bool operator==(const gfpvar_& other) const; bool operator!=(const gfpvar_& other) const; - void add(octetStream& other, int = -1); - void negate(); gfpvar_ invert() const; diff --git a/Networking/Player.cpp b/Networking/Player.cpp index fc9350e68..d70e5d639 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -272,14 +272,14 @@ PlayerBase::~PlayerBase() // Set up nmachines client and server sockets to send data back and fro -// A machine is a server between it and player i if i<=my_number +// A machine is a server between it and player i if i>=my_number // Can also communicate with myself, but only with send_to and receive_from void PlainPlayer::setup_sockets(const vector& names, const vector& ports, const string& id_base, ServerSocket& server) { sockets.resize(nplayers); // Set up the client side - for (int i=player_no; i& names, } send_to_self_socket = sockets[player_no]; // Setting up the server side - for (int i=0; i<=player_no; i++) { + for (int i=player_no; i using 
namespace std; +// default to one minute +#ifndef CONNECTION_TIMEOUT +#define CONNECTION_TIMEOUT 60 +#endif void error(const char *str); @@ -38,10 +42,15 @@ void receive(T& socket, size_t& a, size_t len); inline size_t send_non_blocking(int socket, octet* msg, size_t len) { +#ifdef __APPLE__ + int j = send(socket,msg,min(len,10000lu),MSG_DONTWAIT); +#else int j = send(socket,msg,len,MSG_DONTWAIT); +#endif if (j < 0) { - if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK) + if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK and + errno != ENOBUFS) { error("Send error - 1 "); } else return 0; diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index b996d83b6..2925a3508 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -41,6 +41,8 @@ union square128 { int16_t doublebytes[128][8]; int32_t words[128][4]; + square128() {} + bool get_bit(int x, int y) { return (bytes[x][y/8] >> (y % 8)) & 1; } diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 105a755f2..2752a5da0 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -8,6 +8,7 @@ #include "Math/Setup.h" #include "Tools/Bundle.h" +#include "Instruction.hpp" #include "Protocols/ShuffleSacrifice.hpp" #include @@ -39,11 +40,26 @@ bool BaseMachine::has_program() int BaseMachine::edabit_bucket_size(int n_bits) { - int res = OnlineOptions::singleton.bucket_size; + size_t usage = 0; + if (has_program()) + usage = s().progs[0].get_offline_data_used().total_edabits(n_bits); + return bucket_size(usage); +} +int BaseMachine::triple_bucket_size(DataFieldType type) +{ + size_t usage = 0; if (has_program()) + usage = s().progs[0].get_offline_data_used().files[type][DATA_TRIPLE]; + return bucket_size(usage); +} + +int BaseMachine::bucket_size(size_t usage) +{ + int res = OnlineOptions::singleton.bucket_size; + + if (usage) { - auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits); for (int B = res; B <= 5; B++) if (ShuffleSacrifice(B).minimum_n_outputs() < 
usage * .9) break; @@ -91,7 +107,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) string threadname; for (int i=0; i> threadname; - size_t split = threadname.find(":"); + size_t split = threadname.find_last_of(":"); long expected = -1; if (split != string::npos) { @@ -125,6 +141,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) getline(inpf, compiler); getline(inpf, domain); getline(inpf, relevant_opts); + getline(inpf, security); inpf.close(); } @@ -184,17 +201,19 @@ string BaseMachine::memory_filename(const string& type_short, int my_number) string BaseMachine::get_domain(string progname) { - if (singleton) - { - assert(s().progname == progname); - return s().domain; - } + return get_basics(progname).domain; +} + +BaseMachine BaseMachine::get_basics(string progname) +{ + if (singleton and s().progname == progname) + return s(); - assert(not singleton); + auto backup = singleton; BaseMachine machine; - singleton = 0; + singleton = backup; machine.load_schedule(progname, false); - return machine.domain; + return machine; } int BaseMachine::ring_size_from_schedule(string progname) @@ -226,6 +245,15 @@ bigint BaseMachine::prime_from_schedule(string progname) return 0; } +int BaseMachine::security_from_schedule(string progname) +{ + string sec = get_basics(progname).security; + if (sec.substr(0, 4).compare("sec:") == 0) + return stoi(sec.substr(4)); + else + return 0; +} + NamedCommStats BaseMachine::total_comm() { NamedCommStats res; diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 7d2da9be3..e522dcfe1 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -32,10 +32,13 @@ class BaseMachine string compiler; string domain; string relevant_opts; + string security; virtual size_t load_program(const string& threadname, const string& filename); + static BaseMachine get_basics(string progname); + public: static thread_local int thread_num; @@ -58,12 +61,15 @@ class BaseMachine 
static int ring_size_from_schedule(string progname); static int prime_length_from_schedule(string progname); static bigint prime_from_schedule(string progname); + static int security_from_schedule(string progname); template static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0); template static int edabit_batch_size(int n_bits, int buffer_size = 0); static int edabit_bucket_size(int n_bits); + static int triple_bucket_size(DataFieldType type); + static int bucket_size(size_t usage); BaseMachine(); virtual ~BaseMachine() {} diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index 2c8036cda..1bf8136e8 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -1,5 +1,7 @@ #include "Processor/ExternalClients.h" +#include "Processor/OnlineOptions.h" #include "Networking/ServerSocket.h" +#include "Networking/ssl_sockets.h" #include #include #include @@ -25,6 +27,8 @@ ExternalClients::~ExternalClients() } if (ctx) delete ctx; + for (auto it = peer_ctxs.begin(); it != peer_ctxs.end(); it++) + delete it->second; } void ExternalClients::start_listening(int portnum_base) @@ -32,8 +36,9 @@ void ExternalClients::start_listening(int portnum_base) ScopeLock _(lock); client_connection_servers[portnum_base] = new AnonymousServerSocket(portnum_base + get_party_num()); client_connection_servers[portnum_base]->init(); - cerr << "Start listening on thread " << this_thread::get_id() << endl; - cerr << "Party " << get_party_num() << " is listening on port " << (portnum_base + get_party_num()) + if (OnlineOptions::singleton.verbose) + cerr << "Party " << get_party_num() << " is listening on port " + << (portnum_base + get_party_num()) << " for external client connections." << endl; } @@ -46,7 +51,6 @@ int ExternalClients::get_client_connection(int portnum_base) cerr << "Thread " << this_thread::get_id() << " didn't find server." 
<< endl; throw runtime_error("No connection on port " + to_string(portnum_base)); } - cerr << "Thread " << this_thread::get_id() << " found server." << endl; int client_id, socket; string client; socket = client_connection_servers[portnum_base]->get_connection_socket( @@ -57,10 +61,38 @@ int ExternalClients::get_client_connection(int portnum_base) external_client_sockets[client_id] = new client_socket(io_service, *ctx, socket, "C" + to_string(client_id), "P" + to_string(get_party_num()), false); client_ports[client_id] = portnum_base; - cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl; + if (OnlineOptions::singleton.verbose) + cerr << "Party " << get_party_num() + << " received external client connection from client id: " << dec + << client_id << endl; return client_id; } +int ExternalClients::init_client_connection(const string& host, int portnum, + int my_client_id) +{ + ScopeLock _(lock); + int plain_socket; + set_up_client_socket(plain_socket, host.c_str(), portnum); + octetStream(to_string(my_client_id)).Send(plain_socket); + string my_client_name = "C" + to_string(my_client_id); + if (peer_ctxs.find(my_client_id) == peer_ctxs.end()) + peer_ctxs[my_client_id] = new client_ctx(my_client_name); + auto socket = new client_socket(io_service, *peer_ctxs[my_client_id], + plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id), + true); + if (party_num == 0) + { + octetStream specification; + specification.Receive(socket); + } + int id = -1; + if (not external_client_sockets.empty()) + id = min(id, external_client_sockets.begin()->first); + external_client_sockets[id] = socket; + return id; +} + void ExternalClients::close_connection(int client_id) { ScopeLock _(lock); diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h index bada59b40..b9030fd5a 100644 --- a/Processor/ExternalClients.h +++ b/Processor/ExternalClients.h @@ -32,6 +32,7 @@ class ExternalClients 
ssl_service io_service; client_ctx* ctx; + map peer_ctxs; Lock lock; @@ -43,6 +44,7 @@ class ExternalClients void start_listening(int portnum_base); int get_client_connection(int portnum_base); + int init_client_connection(const string& host, int portnum, int my_client_id); void close_connection(int client_id); diff --git a/Processor/FieldMachine.hpp b/Processor/FieldMachine.hpp index 89ec66e1c..926fefe8c 100644 --- a/Processor/FieldMachine.hpp +++ b/Processor/FieldMachine.hpp @@ -12,7 +12,6 @@ #include "OnlineMachine.hpp" #include "OnlineOptions.hpp" - template class T, class V> HonestMajorityFieldMachine::HonestMajorityFieldMachine(int argc, const char **argv) @@ -34,6 +33,7 @@ template class T, template class V, class W, class X> FieldMachine::FieldMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) { + assert(nplayers or T::variable_players); W machine(argc, argv, opt, online_opts, X(), nplayers); int n_limbs = online_opts.prime_limbs(); switch (n_limbs) diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 7cb309fc2..999fb9eaa 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -10,11 +10,13 @@ #include "Math/gf2n.h" #include "GC/instructions.h" +#include "Memory.hpp" + #include template void Instruction::execute_clear_gf2n(vector& registers, - vector& memory, ArithmeticProcessor& Proc) const + MemoryPart& memory, ArithmeticProcessor& Proc) const { auto& C2 = registers; auto& M2C = memory; @@ -123,6 +125,6 @@ ostream& operator<<(ostream& s, const Instruction& instr) } template void Instruction::execute_clear_gf2n(vector& registers, - vector& memory, ArithmeticProcessor& Proc) const; + MemoryPart& memory, ArithmeticProcessor& Proc) const; template void Instruction::execute_clear_gf2n(vector& registers, - vector& memory, ArithmeticProcessor& Proc) const; + MemoryPart& memory, ArithmeticProcessor& Proc) const; diff --git a/Processor/Instruction.h b/Processor/Instruction.h 
index 36ffbed57..0121b6d14 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -72,6 +72,7 @@ enum USE_EDABIT = 0xE5, USE_MATMUL = 0x1F, ACTIVE = 0xE9, + CMDLINEARG = 0xEB, // Addition ADDC = 0x20, ADDS = 0x21, @@ -153,7 +154,7 @@ enum LISTEN = 0x6c, ACCEPTCLIENTCONNECTION = 0x6d, CLOSECLIENTCONNECTION = 0x6e, - READCLIENTPUBLICKEY = 0x6f, + INITCLIENTCONNECTION = 0x6f, // Bitwise logic ANDC = 0x70, XORC = 0x71, @@ -197,6 +198,7 @@ enum PRINTREG = 0XB1, RAND = 0xB2, PRINTREGPLAIN = 0xB3, + PRINTREGPLAINS = 0xEA, PRINTCHR = 0xB4, PRINTSTR = 0xB5, PUBINPUT = 0xB6, @@ -345,6 +347,7 @@ class BaseInstruction int r[4]; // Fixed parameter registers size_t n; // Possible immediate value vector start; // Values for a start/stop open + string str; public: virtual ~BaseInstruction() {}; @@ -387,7 +390,7 @@ class Instruction : public BaseInstruction void execute(Processor& Proc) const; template - void execute_clear_gf2n(vector& registers, vector& memory, + void execute_clear_gf2n(vector& registers, MemoryPart& memory, ArithmeticProcessor& Proc) const; template diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index bb555281b..c91cf6314 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -105,7 +105,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case STMCBI: case MOVC: case MOVS: - case MOVSB: case MOVINT: case LDMINTI: case STMINTI: @@ -131,6 +130,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case SHUFFLE: case ACCEPTCLIENTCONNECTION: case PREFIXSUMS: + case CMDLINEARG: get_ints(r, s, 2); break; // instructions with 1 register operand @@ -139,6 +139,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RANDOMFULLS: case PRINTREGPLAIN: case PRINTREGPLAINB: + case PRINTREGPLAINS: case LDTN: case LDARG: case STARG: @@ -316,13 +317,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case TRUNC_PR: case RUN_TAPE: 
case CONV2DS: + case MATMULS: num_var_args = get_int(s); get_vector(num_var_args, start, s); break; - case MATMULS: - get_ints(r, s, 3); - get_vector(3, start, s); - break; case MATMULSM: get_ints(r, s, 3); get_vector(9, start, s); @@ -358,7 +356,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) n = get_int(s); get_vector(num_var_args, start, s); break; - case READCLIENTPUBLICKEY: + case INITCLIENTCONNECTION: + get_ints(r, s, 3); + get_string(str, s); + break; case INITSECURESOCKET: case RESPSECURESOCKET: throw runtime_error("VM-controlled encryption not supported any more"); @@ -459,6 +460,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CONVCBIT2S: case NOTS: case NOTCB: + case MOVSB: n = get_int(s); get_ints(r, s, 2); break; @@ -566,7 +568,7 @@ int BaseInstruction::get_reg_type() const case MOVINT: case READSOCKETINT: case WRITESOCKETINT: - case READCLIENTPUBLICKEY: + case INITCLIENTCONNECTION: case INITSECURESOCKET: case RESPSECURESOCKET: case LDARG: @@ -584,6 +586,7 @@ int BaseInstruction::get_reg_type() const case INTOUTPUT: case ACCEPTCLIENTCONNECTION: case GENSECSHUFFLE: + case CMDLINEARG: return INT; case PREP: case GPREP: @@ -723,6 +726,15 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const return res; } case MATMULS: + { + int res = 0; + for (auto it = start.begin(); it < start.end(); it += 6) + { + int tmp = *it + *(it + 3) * *(it + 5); + res = max(res, tmp); + } + return res; + } case MATMULSM: return r[0] + start[0] * start[2]; case CONV2DS: @@ -817,7 +829,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const while (it < start.end()) { int n = *it - n_prefix; - int size = DIV_CEIL(*(it + 1), 64); + size = max((long long) size, DIV_CEIL(*(it + 1), 64)); it += n_prefix; assert(it + n <= start.end()); for (int i = 0; i < n; i++) @@ -922,16 +934,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.write_Cp(r[0],Proc.machine.Mp.read_C(n)); n++; break; - case LDMCI: 
- Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1]))); - break; case STMC: Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0])); n++; break; - case STMCI: - Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0])); - break; case MOVC: Proc.write_Cp(r[0],Proc.read_Cp(r[1])); break; @@ -1089,10 +1095,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.Proc2.POpen(*this); return; case MULS: - Proc.Procp.muls(start, size); + Proc.Procp.muls(start); return; case GMULS: - Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size); + Proc.Proc2.muls(start); return; case MULRS: Proc.Procp.mulrs(start); @@ -1107,7 +1113,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.Proc2.dotprods(start, size); return; case MATMULS: - Proc.Procp.matmuls(Proc.Procp.get_S(), *this, r[1], r[2]); + Proc.Procp.matmuls(Proc.Procp.get_S(), *this); return; case MATMULSM: Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this, @@ -1126,13 +1132,15 @@ inline void Instruction::execute(Processor& Proc) const Proc.Proc2.secure_shuffle(*this); return; case GENSECSHUFFLE: - Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this)); + Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this, + Proc.machine.shuffle_store)); return; case APPLYSHUFFLE: - Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3))); + Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)), + Proc.machine.shuffle_store); return; case DELSHUFFLE: - Proc.Procp.delete_shuffle(Proc.read_Ci(r[0])); + Proc.machine.shuffle_store.del(Proc.read_Ci(r[0])); return; case INVPERM: Proc.Procp.inverse_permutation(*this); @@ -1170,6 +1178,9 @@ inline void Instruction::execute(Processor& Proc) const case PRINTREGPLAIN: print(Proc.out, &Proc.read_Cp(r[0])); return; + case PRINTREGPLAINS: + Proc.out << Proc.read_Sp(r[0]); + return; case CONDPRINTPLAIN: if (not Proc.read_Cp(r[0]).is_zero()) { @@ -1237,6 +1248,19 @@ inline void Instruction::execute(Processor& Proc) const case 
PLAYERID: Proc.write_Ci(r[0], Proc.P.my_num()); break; + case CMDLINEARG: + { + size_t idx = Proc.read_Ci(r[1]); + auto& args = OnlineOptions::singleton.args; + if (idx < args.size()) + Proc.write_Ci(r[0], args[idx]); + else + { + cerr << idx << "-th command-line argument not given" << endl; + exit(1); + } + break; + } // *** // TODO: read/write shared GF(2^n) data instructions // *** @@ -1255,11 +1279,17 @@ inline void Instruction::execute(Processor& Proc) const octetStream os; os.store(int(sint::open_type::type_char())); sint::specification(os); + sint::clear::specification(os); os.Send(Proc.external_clients.get_socket(client_handle)); } Proc.write_Ci(r[0], client_handle); break; } + case INITCLIENTCONNECTION: + Proc.write_Ci(r[0], + Proc.external_clients.init_client_connection(str, + Proc.read_Ci(r[1]), Proc.read_Ci(r[2]))); + break; case CLOSECLIENTCONNECTION: Proc.external_clients.close_connection(Proc.read_Ci(r[0])); break; diff --git a/Processor/Machine.h b/Processor/Machine.h index 803e3d919..9c27fa4b2 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -20,6 +20,8 @@ #include "Tools/time-func.h" #include "Tools/ExecutionStats.h" +#include "Protocols/SecureShuffle.h" + #include #include #include @@ -70,6 +72,8 @@ class Machine : public BaseMachine ExternalClients external_clients; + typename sint::Protocol::Shuffler::store_type shuffle_store; + static void init_binary_domains(int security_parameter, int lg2); Machine(Names& playerNames, bool use_encryption = true, diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index d9c245819..1bb285eb3 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -60,11 +60,21 @@ Machine::Machine(Names& playerNames, bool use_encryption, { OnlineOptions::singleton = opts; - if (N.num_players() == 1 and sint::is_real) + int min_players = 3 - sint::dishonest_majority; + if (sint::is_real) { - cerr << "Need more than one player to run a protocol." 
<< endl; - cerr << "Use 'emulate.x' for just running the virtual machine" << endl; - exit(1); + if (N.num_players() == 1) + { + cerr << "Need more than one player to run a protocol." << endl; + cerr << "Use 'emulate.x' for just running the virtual machine" << endl; + exit(1); + } + else if (N.num_players() < min_players) + { + cerr << "Need at least " << min_players << " players for this protocol." + << endl; + exit(1); + } } // Set the prime modulus from command line or program if applicable @@ -480,8 +490,10 @@ void Machine::run(const string& progname) if (opts.verbose) { - cerr << "Communication details " - "(rounds in parallel threads counted double):" << endl; + cerr << "Communication details"; + if (multithread) + cerr << " (rounds in parallel threads counted double)"; + cerr << ":" << endl; comm_stats.print(); cerr << "CPU time = " << proc_timer.elapsed(); if (multithread) @@ -547,6 +559,14 @@ void Machine::run(const string& progname) suggest_optimizations(); + if (N.num_players() > 4) + { + string alt = sint::alt(); + if (alt.size()) + cerr << "This protocol doesn't scale well with the number of parties, " + << "have you considered using " << alt << " instead?" 
<< endl; + } + #ifdef VERBOSE cerr << "End of prog" << endl; #endif diff --git a/Processor/Memory.h b/Processor/Memory.h index 16e885485..649c745b1 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -14,34 +14,90 @@ template istream& operator>>(istream& s,Memory& M); #include "Processor/Program.h" #include "Tools/CheckVector.h" +#include "Tools/DiskVector.h" template -class MemoryPart : public CheckVector +class MemoryPart { public: - template - static void check_index(const vector& M, size_t i) + virtual ~MemoryPart() {} + + virtual size_t size() const = 0; + virtual void resize(size_t) = 0; + + virtual T* data() = 0; + virtual const T* data() const = 0; + + void check_index(size_t i) const { - (void) M, (void) i; + (void) i; #ifndef NO_CHECK_INDEX - if (i >= M.size()) - throw overflow(U::type_string() + " memory", i, M.size()); + if (i >= this->size()) + throw overflow(T::type_string() + " memory", i, this->size()); #endif } + virtual T& operator[](size_t i) = 0; + virtual const T& operator[](size_t i) const = 0; + + virtual T& at(size_t i) = 0; + virtual const T& at(size_t i) const = 0; + + template + void indirect_read(const Instruction& inst, vector& regs, + const U& indices); + template + void indirect_write(const Instruction& inst, vector& regs, + const U& indices); + + void minimum_size(size_t size); +}; + +template class V> +class MemoryPartImpl : public MemoryPart, public V +{ +public: + size_t size() const + { + return V::size(); + } + + void resize(size_t size) + { + V::resize(size); + } + + T* data() + { + return V::data(); + } + + const T* data() const + { + return V::data(); + } + T& operator[](size_t i) { - check_index(*this, i); - return CheckVector::operator[](i); + this->check_index(i); + return V::operator[](i); } const T& operator[](size_t i) const { - check_index(*this, i); - return CheckVector::operator[](i); + this->check_index(i); + return V::operator[](i); } - void minimum_size(size_t size); + T& at(size_t i) + { + return 
V::at(i); + } + + const T& at(size_t i) const + { + return V::at(i); + } }; template @@ -49,8 +105,11 @@ class Memory { public: - MemoryPart MS; - MemoryPart MC; + MemoryPart& MS; + MemoryPartImpl MC; + + Memory(); + ~Memory(); void resize_s(size_t sz) { MS.resize(sz); } diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index ef767441b..23a41b1fa 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -3,6 +3,54 @@ #include +template +template +void MemoryPart::indirect_read(const Instruction& inst, + vector& regs, const U& indices) +{ + size_t n = inst.get_size(); + auto dest = regs.begin() + inst.get_r(0); + auto start = indices.begin() + inst.get_r(1); +#ifdef CHECK_SIZE + assert(start + n <= indices.end()); + assert(dest + n <= regs.end()); +#endif + long size = this->size(); + const T* data = this->data(); + for (auto it = start; it < start + n; it++) + { +#ifndef NO_CHECK_SIZE + if (*it >= size) + throw overflow(T::type_string() + " memory read", it->get(), size); +#endif + *dest++ = data[it->get()]; + } +} + +template +template +void MemoryPart::indirect_write(const Instruction& inst, + vector& regs, const U& indices) +{ + size_t n = inst.get_size(); + auto source = regs.begin() + inst.get_r(0); + auto start = indices.begin() + inst.get_r(1); +#ifdef CHECK_SIZE + assert(start + n <= indices.end()); + assert(source + n <= regs.end()); +#endif + long size = this->size(); + T* data = this->data(); + for (auto it = start; it < start + n; it++) + { +#ifndef NO_CHECK_SIZE + if (*it >= size) + throw overflow(T::type_string() + " memory write", it->get(), size); +#endif + data[it->get()] = *source++; + } +} + template void Memory::minimum_size(RegType secret_type, RegType clear_type, const Program &program, const string& threadname) @@ -29,6 +77,21 @@ void MemoryPart::minimum_size(size_t size) } } +template +Memory::Memory() : + MS( + *(OnlineOptions::singleton.disk_memory.size() ? 
+ static_cast*>(new MemoryPartImpl) : + static_cast*>(new MemoryPartImpl))) +{ +} + +template +Memory::~Memory() +{ + delete &MS; +} + template ostream& operator<<(ostream& s,const Memory& M) { diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 2792198f5..078a39199 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -71,18 +71,6 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op "--ip-file-name" // Flag token. ); - if (nplayers == 0) - opt.add( - "2", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Number of players (default: 2). " - "Ignored if external server is used.", // Help description. - "-N", // Flag token. - "--nparties" // Flag token. - ); - opt.add( "", // Default. 0, // Required? diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index b4bf6594e..ad0a0403d 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -22,12 +22,13 @@ OnlineOptions::OnlineOptions() : playerno(-1) interactive = false; lgp = gfp0::MAX_N_BITS; live_prep = true; - batch_size = 10000; + batch_size = 1000; memtype = "empty"; bits_from_squares = false; direct = false; bucket_size = 4; security_parameter = DEFAULT_SECURITY; + use_security_parameter = false; cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; file_prep_per_thread = false; @@ -46,6 +47,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, bool security) : OnlineOptions() { + use_security_parameter = security; + opt.syntax = std::string(argv[0]) + " [OPTIONS] [] "; opt.add( @@ -116,7 +119,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. 
- ("Security parameter (default: " + to_string(security_parameter) + ("Statistical ecurity parameter (default: " + to_string(security_parameter) + ")").c_str(), // Help description. "-S", // Flag token. "--security" // Flag token. @@ -138,7 +141,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, if (security) { opt.get("-S")->getInt(security_parameter); - cerr << "Using security parameter " << security_parameter << endl; if (security_parameter <= 0) { cerr << "Invalid security parameter: " << security_parameter << endl; @@ -280,7 +282,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, } void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, - const char** argv) + const char** argv, bool networking) { opt.resetArgs(); opt.parse(argc, argv); @@ -292,17 +294,21 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, vector badOptions; unsigned int i; - opt.footer += "\nSee also https://mp-spdz.readthedocs.io/en/latest/networking.html " - "for documentation on the networking setup.\n"; + if (networking) + opt.footer += "See also " + "https://mp-spdz.readthedocs.io/en/latest/networking.html " + "for documentation on the networking setup.\n\n"; + + size_t name_index = 1 + networking - opt.isSet("-p"); - if (allArgs.size() != 3u - opt.isSet("-p")) + if (allArgs.size() < name_index + 1) { + opt.getUsage(usage); + cout << usage; cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl; cerr << "Arguments given were:\n"; for (unsigned int j = 1; j < allArgs.size(); j++) cout << "'" << *allArgs[j] << "'" << endl; - opt.getUsage(usage); - cout << usage; exit(1); } else @@ -311,25 +317,25 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, opt.get("-p")->getInt(playerno); else sscanf((*allArgs[1]).c_str(), "%d", &playerno); - progname = *allArgs[2 - opt.isSet("-p")]; + progname = *allArgs.at(name_index); } if (!opt.gotRequired(badOptions)) { - for (i = 0; i < badOptions.size(); ++i) - cerr << 
"ERROR: Missing required option " << badOptions[i] << "."; opt.getUsage(usage); cout << usage; + for (i = 0; i < badOptions.size(); ++i) + cerr << "ERROR: Missing required option " << badOptions[i] << "."; exit(1); } if (!opt.gotExpected(badOptions)) { + opt.getUsage(usage); + cout << usage; for (i = 0; i < badOptions.size(); ++i) cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << "."; - opt.getUsage(usage); - cout << usage; exit(1); } @@ -347,6 +353,22 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, prime = schedule_prime; } + for (size_t i = name_index + 1; i < allArgs.size(); i++) + { + try + { + args.push_back(stol(*allArgs[i])); + } + catch (exception& e) + { + opt.getUsage(usage); + cerr << usage; + cerr << "Additional argument has to be integer: " << *allArgs[i] + << endl; + exit(1); + } + } + // ignore program if length explicitly set from command line if (opt.get("-lgp") and not opt.isSet("-lgp")) { @@ -367,7 +389,29 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, if (o) o->getInt(max_broadcast); + o = opt.get("--disk-memory"); + if (o) + o->getString(disk_memory); + receive_threads = opt.isSet("--threads"); + + if (use_security_parameter) + { + int program_sec = BaseMachine::security_from_schedule(progname); + + if (program_sec > 0) + { + if (not opt.isSet("-S")) + security_parameter = program_sec; + if (program_sec < security_parameter) + { + cerr << "Security parameter used in compilation is insufficient" << endl; + exit(1); + } + } + + cerr << "Using statistical security parameter " << security_parameter << endl; + } } void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt) diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 61c1352bc..7e32c3176 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -27,6 +27,7 @@ class OnlineOptions bool direct; int bucket_size; int security_parameter; + bool use_security_parameter; std::string 
cmd_private_input_file; std::string cmd_private_output_file; bool verbose; @@ -34,6 +35,8 @@ class OnlineOptions int trunc_error; int opening_sum, max_broadcast; bool receive_threads; + std::string disk_memory; + vector args; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, @@ -48,7 +51,8 @@ class OnlineOptions OnlineOptions(T); ~OnlineOptions() {} - void finalize(ez::ezOptionParser& opt, int argc, const char** argv); + void finalize(ez::ezOptionParser& opt, int argc, const char** argv, + bool networking = true); void set_trunc_error(ez::ezOptionParser& opt); diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp index d8b71cea6..c822b5a8e 100644 --- a/Processor/OnlineOptions.hpp +++ b/Processor/OnlineOptions.hpp @@ -11,7 +11,7 @@ template OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T, bool default_live_prep) : - OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0, + OnlineOptions(opt, argc, argv, OnlineOptions(T()).batch_size, default_live_prep, T::clear::prime_field) { if (T::has_trunc_pr) @@ -56,13 +56,39 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--max-broadcast" // Flag token. ); } + + if (not T::clear::binary) + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use directory on disk for memory (container data structures) " + "instead of RAM", // Help description. + "-D", // Flag token. + "--disk-memory" // Flag token. + ); + + if (T::variable_players) + opt.add( + T::dishonest_majority ? "2" : "3", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + ("Number of players (default: " + + (T::dishonest_majority ? + to_string("2") : to_string("3")) + "). " + + "Ignored if external server is used.").c_str(), // Help description. + "-N", // Flag token. + "--nparties" // Flag token. 
+ ); } template OnlineOptions::OnlineOptions(T) : OnlineOptions() { - if (T::dishonest_majority) - batch_size = 1000; + if (not T::dishonest_majority) + batch_size = 10000; } #endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index f6ffade98..9b4757f4e 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -36,7 +36,7 @@ class SubProcessor void resize(size_t size) { C.resize(size); S.resize(size); } - void matmulsm_prep(int ii, int j, const CheckVector& source, + void matmulsm_prep(int ii, int j, const MemoryPart& source, const vector& dim, size_t a, size_t b); void matmulsm_finalize(int i, int j, const vector& dim, typename vector::iterator C); @@ -48,6 +48,8 @@ class SubProcessor typedef typename T::bit_type::part_type BT; + typedef typename T::Protocol::Shuffler::store_type ShuffleStore; + public: ArithmeticProcessor* Proc; typename T::MAC_Check& MC; @@ -71,19 +73,19 @@ class SubProcessor // Access to PO (via calls to POpen start/stop) void POpen(const Instruction& inst); - void muls(const vector& reg, int size); + void muls(const vector& reg); void mulrs(const vector& reg); void dotprods(const vector& reg, int size); - void matmuls(const vector& source, const Instruction& instruction, size_t a, - size_t b); - void matmulsm(const CheckVector& source, const Instruction& instruction, size_t a, + void matmuls(const vector& source, const Instruction& instruction); + void matmulsm(const MemoryPart& source, const Instruction& instruction, size_t a, size_t b); void conv2ds(const Instruction& instruction); void secure_shuffle(const Instruction& instruction); - size_t generate_secure_shuffle(const Instruction& instruction); - void apply_shuffle(const Instruction& instruction, int handle); - void delete_shuffle(int handle); + size_t generate_secure_shuffle(const Instruction& instruction, + ShuffleStore& shuffle_store); + void apply_shuffle(const Instruction& instruction, int handle, + ShuffleStore& 
shuffle_store); void inverse_permutation(const Instruction& instruction); void input_personal(const vector& args); @@ -116,7 +118,7 @@ class SubProcessor class ArithmeticProcessor : public ProcessorBase { protected: - CheckVector Ci; + CheckVector Ci; ofstream public_output; ofstream binary_output; @@ -162,13 +164,13 @@ class ArithmeticProcessor : public ProcessorBase return thread_num; } - const long& read_Ci(size_t i) const - { return Ci[i]; } - long& get_Ci_ref(size_t i) + long read_Ci(size_t i) const + { return Ci[i].get(); } + Integer& get_Ci_ref(size_t i) { return Ci[i]; } void write_Ci(size_t i, const long& x) { Ci[i]=x; } - CheckVector& get_Ci() + CheckVector& get_Ci() { return Ci; } virtual ofstream& get_public_output() diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d2fe438cf..23c080ccf 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -379,9 +379,20 @@ void Processor::read_socket_private(int client_id, client_timer.stop(); client_stats.add(socket_stream.get_length()); - for (int j = 0; j < size; j++) - for (int i = 0; i < m; i++) - get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs); + int j, i; + try + { + for (j = 0; j < size; j++) + for (i = 0; i < m; i++) + get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs); + } + catch (exception& e) + { + throw insufficient_shares(m * size, j * m + i, e); + } + + if (socket_stream.left()) + throw runtime_error("unexpected share data"); } @@ -468,28 +479,29 @@ void SubProcessor::POpen(const Instruction& inst) } template -void SubProcessor::muls(const vector& reg, int size) +void SubProcessor::muls(const vector& reg) { - assert(reg.size() % 3 == 0); - int n = reg.size() / 3; + assert(reg.size() % 4 == 0); + int n = reg.size() / 4; SubProcessor& proc = *this; protocol.init_mul(); for (int i = 0; i < n; i++) - for (int j = 0; j < size; j++) + for (int j = 0; j < reg[4 * i]; j++) { - auto& x = proc.S[reg[3 * i + 1] + j]; - auto& y = proc.S[reg[3 * i + 2] + 
j]; + auto& x = proc.S[reg[4 * i + 2] + j]; + auto& y = proc.S[reg[4 * i + 3] + j]; protocol.prepare_mul(x, y); } protocol.exchange(); for (int i = 0; i < n; i++) - for (int j = 0; j < size; j++) + { + for (int j = 0; j < reg[4 * i]; j++) { - proc.S[reg[3 * i] + j] = protocol.finalize_mul(); + proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul(); } - - protocol.counter += n * size; + protocol.counter += n * reg[4 * i]; + } } template @@ -553,33 +565,46 @@ void SubProcessor::dotprods(const vector& reg, int size) template void SubProcessor::matmuls(const vector& source, - const Instruction& instruction, size_t a, size_t b) + const Instruction& instruction) { - auto& dim = instruction.get_start(); - auto A = source.begin() + a; - auto B = source.begin() + b; - auto C = S.begin() + (instruction.get_r(0)); - assert(A + dim[0] * dim[1] <= source.end()); - assert(B + dim[1] * dim[2] <= source.end()); - assert(C + dim[0] * dim[2] <= S.end()); - protocol.init_dotprod(); - for (int i = 0; i < dim[0]; i++) - for (int j = 0; j < dim[2]; j++) - { - for (int k = 0; k < dim[1]; k++) - protocol.prepare_dotprod(*(A + i * dim[1] + k), - *(B + k * dim[2] + j)); - protocol.next_dotprod(); - } + + auto& start = instruction.get_start(); + assert(start.size() % 6 == 0); + + for(auto it = start.begin(); it < start.end(); it += 6) + { + auto dim = it + 3; + auto A = source.begin() + *(it + 1); + auto B = source.begin() + *(it + 2); + assert(A + dim[0] * dim[1] <= source.end()); + assert(B + dim[1] * dim[2] <= source.end()); + + for (int i = 0; i < dim[0]; i++) + for (int j = 0; j < dim[2]; j++) + { + for (int k = 0; k < dim[1]; k++) + protocol.prepare_dotprod(*(A + i * dim[1] + k), + *(B + k * dim[2] + j)); + protocol.next_dotprod(); + } + } + protocol.exchange(); - for (int i = 0; i < dim[0]; i++) - for (int j = 0; j < dim[2]; j++) - *(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]); + + for(auto it = start.begin(); it < start.end(); it += 6) + { + auto C = S.begin() + *it; + auto 
dim = it + 3; + assert(C + dim[0] * dim[2] <= S.end()); + for (int i = 0; i < dim[0]; i++) + for (int j = 0; j < dim[2]; j++) + *(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]); + } } template -void SubProcessor::matmulsm(const CheckVector& source, +void SubProcessor::matmulsm(const MemoryPart& source, const Instruction& instruction, size_t a, size_t b) { auto& dim = instruction.get_start(); @@ -592,7 +617,7 @@ void SubProcessor::matmulsm(const CheckVector& source, protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) { - auto ii = Proc->get_Ci().at(dim[3] + i); + auto ii = Proc->get_Ci().at(dim[3] + i).get(); for (int j = 0; j < dim[2]; j++) { #ifdef DEBUG_MATMULSM @@ -628,16 +653,21 @@ void SubProcessor::matmulsm(const CheckVector& source, } template -void SubProcessor::matmulsm_prep(int ii, int j, const CheckVector& source, +void SubProcessor::matmulsm_prep(int ii, int j, const MemoryPart& source, const vector& dim, size_t a, size_t b) { - auto jj = Proc->get_Ci().at(dim[6] + j); + auto jj = Proc->get_Ci().at(dim[6] + j).get(); + const T* base = source.data(); + size_t size = source.size(); for (int k = 0; k < dim[1]; k++) { - auto kk = Proc->get_Ci().at(dim[4] + k); - auto ll = Proc->get_Ci().at(dim[5] + k); - protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk), - source.at(b + ll * dim[8] + jj)); + auto kk = Proc->get_Ci().at(dim[4] + k).get(); + auto ll = Proc->get_Ci().at(dim[5] + k).get(); + auto aa = a + ii * dim[7] + kk; + auto bb = b + ll * dim[8] + jj; + assert(aa < size); + assert(bb < size); + protocol.prepare_dotprod(base[aa], base[bb]); } protocol.next_dotprod(); } @@ -655,16 +685,22 @@ void SubProcessor::matmulsm_finalize(int i, int j, const vector& dim, template void SubProcessor::conv2ds(const Instruction& instruction) { - protocol.init_dotprod(); auto& args = instruction.get_start(); vector tuples; for (size_t i = 0; i < args.size(); i += 15) tuples.push_back(Conv2dTuple(args, i)); - for (auto& tuple : tuples) - tuple.pre(S, 
protocol); - protocol.exchange(); - for (auto& tuple : tuples) - tuple.post(S, protocol); + size_t done = 0; + while (done < tuples.size()) + { + protocol.init_dotprod(); + size_t i; + for (i = done; i < tuples.size() and protocol.get_buffer_size() < + OnlineOptions::singleton.batch_size; i++) + tuples[i].pre(S, protocol); + protocol.exchange(); + for (; done < i; done++) + tuples[done].post(S, protocol); + } } inline @@ -766,25 +802,22 @@ void SubProcessor::secure_shuffle(const Instruction& instruction) } template -size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction) +size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction, + ShuffleStore& shuffle_store) { - return shuffler.generate(instruction.get_n()); + return shuffler.generate(instruction.get_n(), shuffle_store); } template -void SubProcessor::apply_shuffle(const Instruction& instruction, int handle) +void SubProcessor::apply_shuffle(const Instruction& instruction, int handle, + ShuffleStore& shuffle_store) { shuffler.apply(S, instruction.get_size(), instruction.get_start()[2], - instruction.get_start()[0], instruction.get_start()[1], handle, + instruction.get_start()[0], instruction.get_start()[1], + shuffle_store.get(handle), instruction.get_start()[4]); } -template -void SubProcessor::delete_shuffle(int handle) -{ - shuffler.del(handle); -} - template void SubProcessor::inverse_permutation(const Instruction& instruction) { shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0], @@ -796,17 +829,26 @@ void SubProcessor::input_personal(const vector& args) { input.reset_all(P); for (size_t i = 0; i < args.size(); i += 4) - for (int j = 0; j < args[i]; j++) + if (args[i + 1] == P.my_num()) { - if (args[i + 1] == P.my_num()) - input.add_mine(C[args[i + 3] + j]); - else - input.add_other(args[i + 1]); + auto begin = C.begin() + args[i + 3]; + auto end = begin + args[i]; + assert(end <= C.end()); + for (auto it = begin; it < end; it++) + 
input.add_mine(*it); } + else + for (int j = 0; j < args[i]; j++) + input.add_other(args[i + 1]); input.exchange(); for (size_t i = 0; i < args.size(); i += 4) - for (int j = 0; j < args[i]; j++) - S[args[i + 2] + j] = input.finalize(args[i + 1]); + { + auto begin = S.begin() + args[i + 2]; + auto end = begin + args[i]; + assert(end <= S.end()); + for (auto it = begin; it < end; it++) + *it = input.finalize(args[i + 1]); + } } /** @@ -858,6 +900,16 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } +template +void fixinput_int(T& proc, const Instruction& instruction, U) +{ + U* x = new U[instruction.get_size()]; + proc.binary_input.read((char*) x, sizeof(U) * instruction.get_size()); + for (int i = 0; i < instruction.get_size(); i++) + proc.write_Cp(instruction.get_r(0) + i, x[i]); + delete[] x; +} + template void Processor::fixinput(const Instruction& instruction) { @@ -878,19 +930,24 @@ void Processor::fixinput(const Instruction& instruction) throw runtime_error("unknown format for fixed-point input"); } - for (int i = 0; i < instruction.get_size(); i++) + if (binary_input.fail()) + throw IO_Error("failure reading from " + binary_input_filename); + + if (binary_input.peek() == EOF) + throw IO_Error("not enough inputs in " + binary_input_filename); + + if (instruction.get_r(2) == 0) { - if (binary_input.peek() == EOF) - throw IO_Error("not enough inputs in " + binary_input_filename); - double buf; - if (instruction.get_r(2) == 0) - { - int64_t x; - binary_input.read((char*) &x, sizeof(x)); - tmp = x; - } + if (instruction.get_r(1) == 1) + fixinput_int(*this, instruction, int8_t()); else + fixinput_int(*this, instruction, int64_t()); + } + else + { + for (int i = 0; i < instruction.get_size(); i++) { + double buf; if (use_double) binary_input.read((char*) &buf, sizeof(double)); else @@ -900,11 +957,12 @@ void Processor::fixinput(const Instruction& instruction) buf = x; } tmp = bigint::tmp = round(buf * exp2(instruction.get_r(1))); + 
write_Cp(instruction.get_r(0) + i, tmp); } - if (binary_input.fail()) - throw IO_Error("failure reading from " + binary_input_filename); - write_Cp(instruction.get_r(0) + i, tmp); } + + if (binary_input.fail()) + throw IO_Error("failure reading from " + binary_input_filename); } } diff --git a/Processor/ProcessorBase.h b/Processor/ProcessorBase.h index d30de5d30..d33dea42f 100644 --- a/Processor/ProcessorBase.h +++ b/Processor/ProcessorBase.h @@ -14,11 +14,12 @@ using namespace std; #include "Tools/ExecutionStats.h" #include "Tools/SwitchableOutput.h" #include "OnlineOptions.h" +#include "Math/Integer.h" class ProcessorBase { // Stack - stack stacki; + stack stacki; ifstream input_file; string input_filename; @@ -26,7 +27,7 @@ class ProcessorBase protected: // Optional argument to tape - int arg; + Integer arg; string get_parameterized_filename(int my_num, int thread_num, const string& prefix); @@ -38,15 +39,15 @@ class ProcessorBase ProcessorBase(); - void pushi(long x) { stacki.push(x); } - void popi(long& x) { x = stacki.top(); stacki.pop(); } + void pushi(Integer x) { stacki.push(x); } + void popi(Integer& x) { x = stacki.top(); stacki.pop(); } - int get_arg() const + Integer get_arg() const { return arg; } - void set_arg(int new_arg) + void set_arg(Integer new_arg) { arg=new_arg; } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index e39e22553..859fbd929 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -41,6 +41,7 @@ template class U, template class V, class W> RingMachine::RingMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) { + assert(nplayers or U<64>::variable_players); RingOptions opts(opt, argc, argv); W machine(argc, argv, opt, online_opts, gf2n(), nplayers); int R = opts.ring_size_from_opts_or_schedule(online_opts.progname); @@ -65,7 +66,7 @@ template class U, template class V> HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecurity( 
int argc, const char** argv, ez::ezOptionParser& opt) { - OnlineOptions online_opts(opt, argc, argv); + OnlineOptions online_opts(opt, argc, argv, U<64, 40>()); RingOptions opts(opt, argc, argv); HonestMajorityMachine machine(argc, argv, opt, online_opts); int R = opts.ring_size_from_opts_or_schedule(online_opts.progname); diff --git a/Processor/instructions.h b/Processor/instructions.h index 756bbf7ca..8db1557e6 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -18,10 +18,10 @@ *dest++ = *source++) \ X(STMS, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.machine.Mp.MS[n], \ *dest++ = *source++) \ - X(LDMSI, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \ - *dest++ = Proc.machine.Mp.read_S(*source++)) \ - X(STMSI, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \ - Proc.machine.Mp.write_S(*dest++, *source++)) \ + X(LDMSI, Proc.machine.Mp.MS.indirect_read(instruction, Procp.get_S(), Proc.get_Ci()),) \ + X(STMSI, Proc.machine.Mp.MS.indirect_write(instruction, Procp.get_S(), Proc.get_Ci()),) \ + X(LDMCI, Proc.machine.Mp.MC.indirect_read(instruction, Procp.get_C(), Proc.get_Ci()),) \ + X(STMCI, Proc.machine.Mp.MC.indirect_write(instruction, Procp.get_C(), Proc.get_Ci()),) \ X(MOVS, auto dest = &Procp.get_S()[r[0]]; auto source = &Procp.get_S()[r[1]], \ *dest++ = *source++) \ X(ADDS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ @@ -121,10 +121,8 @@ *dest++ = *source++) \ X(GSTMS, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.machine.M2.MS[n], \ *dest++ = *source++) \ - X(GLDMSI, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \ - *dest++ = Proc.machine.M2.read_S(*source++)) \ - X(GSTMSI, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \ - Proc.machine.M2.write_S(*dest++, *source++)) \ + X(GLDMSI, Proc.machine.M2.MS.indirect_read(instruction, Proc2.get_S(), Proc.get_Ci()),) \ + X(GSTMSI, 
Proc.machine.M2.MS.indirect_write(instruction, Proc2.get_S(), Proc.get_Ci()),) \ X(GMOVS, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc2.get_S()[r[1]], \ *dest++ = *source++) \ X(GADDS, auto dest = &Proc2.get_S()[r[0]]; auto op1 = &Proc2.get_S()[r[1]]; \ @@ -171,10 +169,8 @@ *dest++ = (*source).get(); source++) \ X(STMINT, auto dest = &Mi[n]; auto source = &Proc.get_Ci()[r[0]], \ *dest++ = *source++) \ - X(LDMINTI, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \ - *dest++ = Mi[*source].get(); source++) \ - X(STMINTI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &Ci[r[0]], \ - Mi[*dest] = *source++; dest++) \ + X(LDMINTI, Mi.indirect_read(*this, Proc.get_Ci(), Proc.get_Ci()),) \ + X(STMINTI, Mi.indirect_write(*this, Proc.get_Ci(), Proc.get_Ci()),) \ X(MOVINT, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \ *dest++ = *source++) \ X(PUSHINT, Proc.pushi(Ci[r[0]]),) \ @@ -213,7 +209,7 @@ X(SHUFFLE, shuffle(Proc),) \ X(BITDECINT, bitdecint(Proc),) \ X(RAND, auto dest = &Ci[r[0]]; auto source = &Ci[r[1]], \ - *dest++ = Proc.shared_prng.get_uint() % (1 << *source++)) \ + *dest++ = Proc.shared_prng.get_uint() % (1 << (*source++).get())) \ #define CLEAR_GF2N_INSTRUCTIONS \ X(GLDI, auto dest = &C2[r[0]]; cgf2n tmp = int(n), \ @@ -222,10 +218,8 @@ *dest++ = (*source).get(); source++) \ X(GSTMC, auto dest = &M2C[n]; auto source = &C2[r[0]], \ *dest++ = *source++) \ - X(GLDMCI, auto dest = &C2[r[0]]; auto source = &Proc.get_Ci()[r[1]], \ - *dest++ = M2C[*source++]) \ - X(GSTMCI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &C2[r[0]], \ - M2C[*dest++] = *source++) \ + X(GLDMCI, M2C.indirect_read(*this, C2, Proc.get_Ci()),) \ + X(GSTMCI, M2C.indirect_write(*this, C2, Proc.get_Ci()),) \ X(GMOVC, auto dest = &C2[r[0]]; auto source = &C2[r[1]], \ *dest++ = *source++) \ X(GADDC, auto dest = &C2[r[0]]; auto op1 = &C2[r[1]]; \ @@ -288,9 +282,7 @@ #define REMAINING_INSTRUCTIONS \ X(CONVMODP, throw not_implemented(),) \ X(LDMC, throw 
not_implemented(),) \ - X(LDMCI, throw not_implemented(),) \ X(STMC, throw not_implemented(),) \ - X(STMCI, throw not_implemented(),) \ X(MOVC, throw not_implemented(),) \ X(DIVC, throw not_implemented(),) \ X(GDIVC, throw not_implemented(),) \ @@ -390,6 +382,8 @@ X(APPLYSHUFFLE, throw not_implemented(),) \ X(DELSHUFFLE, throw not_implemented(),) \ X(ACTIVE, throw not_implemented(),) \ + X(FIXINPUT, throw not_implemented(),) \ + X(CONCATS, throw not_implemented(),) \ #define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS diff --git a/Programs/Source/mnist_full_B.mpc b/Programs/Source/mnist_full_B.mpc index 41d83f313..61692fcf1 100644 --- a/Programs/Source/mnist_full_B.mpc +++ b/Programs/Source/mnist_full_B.mpc @@ -51,7 +51,7 @@ except: pass layers = [ - ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [n_examples, 24, 24, 16], + ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [N, 24, 24, 16], (1, 1), 'VALID'), ml.MaxPool([N, 24, 24, 16]), ml.Relu([N, 12, 12, 16]), diff --git a/Programs/Source/multinode_example_main.py b/Programs/Source/multinode_example_main.py new file mode 100644 index 000000000..17d08bd4f --- /dev/null +++ b/Programs/Source/multinode_example_main.py @@ -0,0 +1,34 @@ +import random + +n_nodes_per_party = int(program.args[1]) +n_threads_per_node = int(program.args[2]) +n_ops_per_thread = int(program.args[3]) + +n_ops_per_node = n_threads_per_node * n_ops_per_thread +n_ops = n_nodes_per_party * n_ops_per_node +data = Array.create_from(sint(regint.inc(n_ops))) + +listen_for_clients(15000) + +ready = regint.Array(n_nodes_per_party) + +@for_range(n_nodes_per_party) +def _(i): + ready[accept_client_connection(15000)] = 1 + +runtime_error_if(sum(ready) != n_nodes_per_party, 'connection problems') + +@for_range(n_nodes_per_party) +def _(i): + data.get_vector(base=i * n_ops_per_node, + size=n_ops_per_node).write_fully_to_socket(i) + +@for_range(n_nodes_per_party) +def 
_(i): + data.assign_vector(sint.read_from_socket(i, size=n_ops_per_node), + base=i * n_ops_per_node) + +for i in range(10): + index = random.randrange(n_ops) + value = data[index].reveal() + runtime_error_if(value != index ** 2, '%s != %s', value, index ** 2) diff --git a/Programs/Source/multinode_example_worker.py b/Programs/Source/multinode_example_worker.py new file mode 100644 index 000000000..2d271c5cc --- /dev/null +++ b/Programs/Source/multinode_example_worker.py @@ -0,0 +1,21 @@ +n_threads = int(program.args[1]) +n_ops_per_thread = int(program.args[2]) +worker_id = int(program.args[3]) + +if len(program.args) > 4: + host = program.args[4] +else: + host = 'localhost' + +n_ops = n_threads * n_ops_per_thread +data = sint.Array(n_ops) + +main = init_client_connection(host, 15000, worker_id) + +data.read_from_socket(main) + +@for_range_opt_multithread(n_threads, n_ops) +def _(i): + data[i] = data[i] ** 2 + +data.write_to_socket(main) diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc index 9d99c37da..b5ef155d4 100644 --- a/Programs/Source/tutorial.mpc +++ b/Programs/Source/tutorial.mpc @@ -1,4 +1,5 @@ # sint: secret integers +# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint # you can assign public numbers to sint @@ -14,6 +15,7 @@ def test(actual, expected): # private inputs are read from Player-Data/Input-P-0 # or from standard input if using command-line option -I +# see https://mp-spdz.readthedocs.io/en/latest/io.html for more options for i in 0, 1: print_ln('got %s from player %s', sint.get_input_from(i).reveal(), i) @@ -62,6 +64,7 @@ test(a[99], 99 * 98) # test(a, 99) # sfix: fixed-point numbers +# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix # set the precision after the dot and in total diff --git a/Protocols/AtlasShare.h b/Protocols/AtlasShare.h index 99afc33b4..bea233a59 100644 --- a/Protocols/AtlasShare.h +++ b/Protocols/AtlasShare.h @@ -35,6 +35,11 @@ 
class AtlasShare : public ShamirShare typedef GC::AtlasSecret bit_type; #endif + static string alt() + { + return ""; + } + AtlasShare() { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index c40308c59..f462862a4 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -7,6 +7,7 @@ #define PROTOCOLS_FAKEPROTOCOL_H_ #include "Replicated.h" +#include "SecureShuffle.h" #include "Math/Z2k.h" #include "Processor/Instruction.h" #include "Processor/TruncPrTuple.h" @@ -17,6 +18,8 @@ template class FakeShuffle { public: + typedef ShuffleStore store_type; + FakeShuffle(SubProcessor&) { } @@ -27,9 +30,9 @@ class FakeShuffle apply(a, n, unit_size, output_base, input_base, 0, 0); } - size_t generate(size_t) + size_t generate(size_t, store_type& store) { - return 0; + return store.add(); } void apply(vector& a, size_t n, int unit_size, size_t output_base, @@ -49,10 +52,6 @@ class FakeShuffle } } - void del(size_t) - { - } - void inverse_permutation(vector&, size_t, size_t, size_t) { } @@ -280,6 +279,19 @@ class FakeProtocol : public ProtocolBase } } } + else if (tag == string("EQZ\0", 4)) + { + for (size_t i = 0; i < args.size(); i += args[i]) + { + assert(i + args[i] <= args.size()); + assert(args[i] == 6); + for (int j = 0; j < args[i + 1]; j++) + { + auto& res = processor.get_S()[args[i + 2] + j]; + res = processor.get_S()[args[i + 3] + j] == 0; + } + } + } else if (tag == "Trun") { for (size_t i = 0; i < args.size(); i += args[i]) diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index a73142b75..7a8d424fd 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -35,6 +35,7 @@ class FakeShare : public T, public ShareInterface static const bool dishonest_majority = false; static const bool malicious = false; static const bool is_real = false; + static const bool variable_players = false; static string type_short() { diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 0d8d2f695..2073eac26 100644 --- a/Protocols/Hemi.h +++ 
b/Protocols/Hemi.h @@ -33,7 +33,7 @@ class Hemi : public T::BasicProtocol ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); - void matmulsm(SubProcessor& processor, CheckVector& source, + void matmulsm(SubProcessor& processor, MemoryPart& source, const Instruction& instruction, int a, int b); void conv2ds(SubProcessor& processor, const Instruction& instruction); }; diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 2b847530e..b232bc42d 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -33,7 +33,7 @@ typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, } template -void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, +void Hemi::matmulsm(SubProcessor& processor, MemoryPart& source, const Instruction& instruction, int a, int b) { if (HemiOptions::singleton.plain_matmul @@ -61,16 +61,16 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, for (int i = 0; i < dim[0]; i++) for (int k = 0; k < dim[1]; k++) { - auto kk = Proc->get_Ci().at(dim[4] + k); - auto ii = Proc->get_Ci().at(dim[3] + i); + auto kk = Proc->get_Ci().at(dim[4] + k).get(); + auto ii = Proc->get_Ci().at(dim[3] + i).get(); A.entries.v.push_back(source.at(a + ii * dim[7] + kk)); } for (int k = 0; k < dim[1]; k++) for (int j = 0; j < dim[2]; j++) { - auto jj = Proc->get_Ci().at(dim[6] + j); - auto ll = Proc->get_Ci().at(dim[5] + k); + auto jj = Proc->get_Ci().at(dim[6] + j).get(); + auto ll = Proc->get_Ci().at(dim[5] + k).get(); B.entries.v.push_back(source.at(b + ll * dim[8] + jj)); } diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index ddf7e186f..0a54f94df 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -34,6 +34,11 @@ class HemiShare : public SemiShare static const bool local_mul = true; static true_type triple_matmul; + static string alt() + { + return "Temi"; + } + HemiShare() { } diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 53dc7e557..988f5ea6a 
100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -250,8 +250,11 @@ TreeSum::~TreeSum() template void TreeSum::run(vector& values, const Player& P) { - start(values, P); - finish(values, P); + if (not values.empty()) + { + start(values, P); + finish(values, P); + } } template @@ -300,9 +303,11 @@ void TreeSum::add_openings(vector& values, const Player& P, P.wait_receive(sender, oss[j]); MC.player_timers[sender].stop(); MC.timers[SUM].start(); + T tmp = values.at(0); for (unsigned int i=0; i::Check(const Player& P) for (auto& os : bundle) if (&os != &bundle.mine) delta += os.get(); - if (not delta.is_zero()) + if (delta != 0) throw mac_fail(); } } @@ -194,8 +194,6 @@ void MAC_Check_::Check(const Player& P) typename U::mac_type a,gami,temp; typename U::mac_type::Scalar h; vector tau(P.num_players()); - a.assign_zero(); - gami.assign_zero(); for (int i=0; i::Check(const Player& P) //cerr << "\tFinal Check" << endl; typename U::mac_type t; - t.assign_zero(); for (int i=0; i::TripleShuffleSacrifice(int B, int C) : { } +template +TripleShuffleSacrifice::TripleShuffleSacrifice(DataFieldType type) : + ShuffleSacrifice(BaseMachine::bucket_size(type)) +{ +} + template void TripleShuffleSacrifice::triple_sacrifice(vector>& triples, vector>& check_triples, Player& P, diff --git a/Protocols/NoShare.h b/Protocols/NoShare.h index 79da7480f..08fb0c3b4 100644 --- a/Protocols/NoShare.h +++ b/Protocols/NoShare.h @@ -44,6 +44,10 @@ class NoShare : public ShareInterface // default private output facility (using input tuples) typedef ::PrivateOutput PrivateOutput; + // indicate whether protocol allows dishonest majority and variable players + static const bool dishonest_majority = true; + static const bool variable_players = true; + // description used for debugging output static string type_string() { @@ -187,4 +191,11 @@ class NoShare : public ShareInterface } }; +template +inline ostream& operator<<(ostream& o, NoShare) +{ + throw runtime_error("no output"); + return 
o; +} + #endif /* PROTOCOLS_NOSHARE_H_ */ diff --git a/Protocols/Rep3Shuffler.h b/Protocols/Rep3Shuffler.h index ec80a48e4..94d86c9c5 100644 --- a/Protocols/Rep3Shuffler.h +++ b/Protocols/Rep3Shuffler.h @@ -6,12 +6,17 @@ #ifndef PROTOCOLS_REP3SHUFFLER_H_ #define PROTOCOLS_REP3SHUFFLER_H_ +#include "SecureShuffle.h" + template class Rep3Shuffler { - SubProcessor& proc; +public: + typedef array, 2> shuffle_type; + typedef ShuffleStore store_type; - vector, 2>> shuffles; +private: + SubProcessor& proc; public: Rep3Shuffler(vector& a, size_t n, int unit_size, size_t output_base, @@ -19,15 +24,13 @@ class Rep3Shuffler Rep3Shuffler(SubProcessor& proc); - int generate(int n_shuffle); + int generate(int n_shuffle, store_type& store); void apply(vector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, int handle, bool reverse); + size_t input_base, shuffle_type& shuffle, bool reverse); void inverse_permutation(vector& stack, size_t n, size_t output_base, size_t input_base); - - void del(int handle); }; #endif /* PROTOCOLS_REP3SHUFFLER_H_ */ diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp index 19de8c4f9..f3a29c84d 100644 --- a/Protocols/Rep3Shuffler.hpp +++ b/Protocols/Rep3Shuffler.hpp @@ -13,9 +13,10 @@ Rep3Shuffler::Rep3Shuffler(vector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc) : proc(proc) { - apply(a, n, unit_size, output_base, input_base, generate(n / unit_size), + store_type store; + int handle = generate(n / unit_size, store); + apply(a, n, unit_size, output_base, input_base, store.get(handle), false); - shuffles.pop_back(); } template @@ -25,10 +26,10 @@ Rep3Shuffler::Rep3Shuffler(SubProcessor& proc) : } template -int Rep3Shuffler::generate(int n_shuffle) +int Rep3Shuffler::generate(int n_shuffle, store_type& store) { - shuffles.push_back({}); - auto& shuffle = shuffles.back(); + int res = store.add(); + auto& shuffle = store.get(res); for (int i = 0; i < 2; i++) { auto& perm = 
shuffle[i]; @@ -40,19 +41,22 @@ int Rep3Shuffler::generate(int n_shuffle) swap(perm[k], perm[k + j]); } } - return shuffles.size() - 1; + return res; } template void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, - size_t output_base, size_t input_base, int handle, bool reverse) + size_t output_base, size_t input_base, shuffle_type& shuffle, + bool reverse) { assert(proc.P.num_players() == 3); assert(not T::malicious); assert(not T::dishonest_majority); assert(n % unit_size == 0); - auto& shuffle = shuffles.at(handle); + if (shuffle.empty()) + throw runtime_error("shuffle has been deleted"); + vector to_shuffle; for (size_t i = 0; i < n; i++) to_shuffle.push_back(a[input_base + i]); @@ -115,12 +119,6 @@ void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, a[output_base + i] = to_shuffle[i]; } -template -void Rep3Shuffler::del(int handle) -{ - shuffles.at(handle) = {}; -} - template void Rep3Shuffler::inverse_permutation(vector&, size_t, size_t, size_t) { diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h index cb4db86c7..d1e383ffb 100644 --- a/Protocols/Rep4Share.h +++ b/Protocols/Rep4Share.h @@ -41,6 +41,7 @@ class Rep4Share : public RepShare typedef GC::Rep4Secret bit_type; static const bool malicious = true; + static const bool variable_players = false; static string type_short() { diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 17c0bacb2..4fb5a6317 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -15,6 +15,7 @@ using namespace std; #include "Tools/random.h" #include "Tools/PointerVector.h" #include "Networking/Player.h" +#include "Processor/Memory.h" template class SubProcessor; template class ReplicatedMC; @@ -68,8 +69,6 @@ class ProtocolBase ProtocolBase(); virtual ~ProtocolBase(); - void muls(const vector& reg, SubProcessor& proc, typename T::MAC_Check& MC, - int size); void mulrs(const vector& reg, SubProcessor& proc); void multiply(vector& products, vector>& multiplicands, @@ -111,7 +110,7 @@ 
class ProtocolBase virtual void randoms_inst(vector&, const Instruction&); template - void matmulsm(SubProcessor & proc, CheckVector& source, + void matmulsm(SubProcessor & proc, MemoryPart& source, const Instruction& instruction, int a, int b) { proc.matmulsm(source, instruction, a, b); } diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 1f0cc5e2f..dc6324451 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -78,14 +78,6 @@ ProtocolBase::~ProtocolBase() #endif } -template -void ProtocolBase::muls(const vector& reg, - SubProcessor& proc, typename T::MAC_Check& MC, int size) -{ - (void)MC; - proc.muls(reg, size); -} - template void ProtocolBase::mulrs(const vector& reg, SubProcessor& proc) diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index c1c265ea8..5601db457 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -9,18 +9,42 @@ #include using namespace std; +#include "Tools/Lock.h" + template class SubProcessor; +template +class ShuffleStore +{ + typedef T shuffle_type; + + deque shuffles; + + Lock store_lock; + + void lock(); + void unlock(); + +public: + int add(); + shuffle_type& get(int handle); + void del(int handle); +}; + template class SecureShuffle { +public: + typedef vector>> shuffle_type; + typedef ShuffleStore store_type; + +private: SubProcessor& proc; vector to_shuffle; vector> config; vector tmp; int unit_size; - vector>>> shuffles; size_t n_shuffle; bool exact; @@ -62,7 +86,7 @@ class SecureShuffle SecureShuffle(SubProcessor& proc); - int generate(int n_shuffle); + int generate(int n_shuffle, store_type& store); /** * @@ -73,12 +97,12 @@ class SecureShuffle * would result in [3,4,1,2] * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) * @param input_base The starting address of the input vector (i.e. 
the location from which to read the permutation) - * @param handle The integer identifying the preconfigured waksman network (shuffle) to use. Such a handle can be obtained from calling + * @param shuffle The preconfigured waksman network (shuffle) to use * @param reverse Boolean indicating whether to apply the inverse of the permutation * @see SecureShuffle::generate for obtaining a shuffle handle */ void apply(vector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, int handle, bool reverse); + size_t input_base, shuffle_type& shuffle, bool reverse); /** * Calculate the secret inverse permutation of stack given secret permutation. @@ -94,8 +118,6 @@ class SecureShuffle * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) */ void inverse_permutation(vector& stack, size_t n, size_t output_base, size_t input_base); - - void del(int handle); }; #endif /* PROTOCOLS_SECURESHUFFLE_H_ */ diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index 752798b2a..f41c3f970 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -12,6 +12,45 @@ #include #include +template +void ShuffleStore::lock() +{ + store_lock.lock(); +} + +template +void ShuffleStore::unlock() +{ + store_lock.unlock(); +} + +template +int ShuffleStore::add() +{ + lock(); + int res = shuffles.size(); + shuffles.push_back({}); + unlock(); + return res; +} + +template +typename ShuffleStore::shuffle_type& ShuffleStore::get(int handle) +{ + lock(); + auto& res = shuffles.at(handle); + unlock(); + return res; +} + +template +void ShuffleStore::del(int handle) +{ + lock(); + shuffles.at(handle) = {}; + unlock(); +} + template SecureShuffle::SecureShuffle(SubProcessor& proc) : proc(proc), unit_size(0), n_shuffle(0), exact(false) @@ -33,13 +72,12 @@ SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, template void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t 
output_base, - size_t input_base, int handle, bool reverse) + size_t input_base, shuffle_type& shuffle, bool reverse) { this->unit_size = unit_size; pre(a, n, input_base); - auto& shuffle = shuffles.at(handle); assert(shuffle.size() == proc.protocol.get_relevant_players().size()); if (reverse) @@ -134,12 +172,6 @@ void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t ou post(stack, n, output_base); } -template -void SecureShuffle::del(int handle) -{ - shuffles.at(handle).clear(); -} - template void SecureShuffle::pre(vector& a, size_t n, size_t input_base) { @@ -230,11 +262,10 @@ void SecureShuffle::player_round(int config_player) { } template -int SecureShuffle::generate(int n_shuffle) +int SecureShuffle::generate(int n_shuffle, store_type& store) { - int res = shuffles.size(); - shuffles.push_back({}); - auto& shuffle = shuffles.back(); + int res = store.add(); + auto& shuffle = store.get(res); for (auto i: proc.protocol.get_relevant_players()) { vector perm; diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index efc7e45f8..12966cfd8 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -65,6 +65,11 @@ class ShamirShare : public T, public ShareInterface return "Shamir " + T::type_string(); } + static string alt() + { + return "ATLAS"; + } + static int threshold(int) { return ShamirMachine::s().threshold; diff --git a/Protocols/Share.h b/Protocols/Share.h index 4a05a049b..cfab66d12 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -85,12 +85,10 @@ class Share_ : public ShareInterface void assign(const char* buffer) { a.assign(buffer); mac.assign(buffer + T::size()); } void assign_zero() - { a.assign_zero(); - mac.assign_zero(); - } + { *this = {}; } void assign(const open_type& aa, int my_num, const typename V::Scalar& alphai); - Share_() { assign_zero(); } + Share_() {} template Share_(const Share_& S) { assign(S); } Share_(const open_type& aa, int my_num, const typename V::Scalar& alphai) diff --git 
a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index b45828e35..187d021d2 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -54,6 +54,8 @@ class ShareInterface static string type_short() { throw runtime_error("shorthand undefined"); } + static string alt() { return ""; } + static bool real_shares(const Player&) { return true; } template diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index 56ae0f0a3..2b24cd533 100644 --- a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -62,6 +62,7 @@ class TripleShuffleSacrifice : public ShuffleSacrifice public: TripleShuffleSacrifice(); TripleShuffleSacrifice(int B, int C); + TripleShuffleSacrifice(DataFieldType type); void triple_sacrifice(vector>& triples, vector>& check_triples, Player& P, diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 7fadd8195..5ca976e40 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -528,7 +528,7 @@ void EdabitShuffleSacrifice::edabit_sacrifice_buckets(vector>& to_c sum <<= n_shift; if (single != sum) { - cout << hex << single << " vs " << (sum << n_shift) << "/" << sum + cout << hex << single << " vs " << sum << endl; throw Offline_Check_Error("edabit shuffle bucket opening"); } diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h index 049881ffe..7d1a0cd05 100644 --- a/Protocols/TemiShare.h +++ b/Protocols/TemiShare.h @@ -31,6 +31,11 @@ class TemiShare : public HemiShare static const bool needs_ot = false; static const bool local_mul = false; + static string alt() + { + return ""; + } + TemiShare() { } diff --git a/README.md b/README.md index f7b88ace7..04434a545 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,9 @@ differences between the various protocols. 
#### Frequently Asked Questions [The documentation](https://mp-spdz.readthedocs.io/en/latest) contains -sections on a number of frequently asked topics as well as information -on how to solve common issues. +a section on a number of [frequently asked +topics](https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html) +as well as information on how to solve common issues. #### TL;DR (Binary Distribution on Linux or Source Distribution on macOS) @@ -354,7 +355,7 @@ There are three ways of running computation: all necessary input and certificate files via SSH. ``` - Scripts/compile-run.py -HOSTS -E mascot -- [...] + Scripts/compile-run.py -H HOSTS -E mascot -- [...] ``` `HOSTS` has to be a text file in the following format: @@ -370,6 +371,8 @@ There are three ways of running computation: user. Otherwise (`//` after the hostname) it will be relative to the root directory. + It is assumed that the SSH login is possible without a password. + Even with the integrated execution it is important to keep in mind that there are two different phases, the compilation and the run-time phase. Any secret data is only available in the second phase, when the @@ -423,8 +426,8 @@ directly. For fixed-point computation this is done via The length is communicated to the virtual machines and automatically used if supported. By default, they support bit lengths 64, 72, and -128. If another length is required, use `MOD = -DRING_SIZE=` in `CONFIG.mine`. +128 (the latter except for SPDZ2k). If another length is required, use +`MOD = -DRING_SIZE=` in `CONFIG.mine`. 
#### Binary circuits diff --git a/Scripts/compile-emulate.py b/Scripts/compile-emulate.py index 5d5fbd2f7..b6a53a969 100755 --- a/Scripts/compile-emulate.py +++ b/Scripts/compile-emulate.py @@ -6,7 +6,7 @@ from Compiler.compilerLib import Compiler -compiler = Compiler() +compiler = Compiler(split_args=True) compiler.prep_compile(build=False) compiler.execute = True compiler.options.execute = 'emulate' diff --git a/Scripts/compile-run.py b/Scripts/compile-run.py index 70aff1bf6..0d031717d 100755 --- a/Scripts/compile-run.py +++ b/Scripts/compile-run.py @@ -6,18 +6,14 @@ from Compiler.compilerLib import Compiler -try: - split = sys.argv.index('--') -except ValueError: - split = len(sys.argv) - -compiler_args = sys.argv[1:split] -runtime_args = sys.argv[split + 1:] -compiler = Compiler(execute=True, custom_args=compiler_args) +compiler = Compiler( + execute=True, split_args=True, + usage="usage: %prog [options] [-E] protocol filename [args] " + "[-- [run-time args]]") compiler.prep_compile() prog = compiler.compile_file() if prog.options.hostfile: - compiler.remote_execution(runtime_args) + compiler.remote_execution() else: - compiler.local_execution(runtime_args) + compiler.local_execution() diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py index 098f90b77..22ed3b212 100755 --- a/Scripts/memory-usage.py +++ b/Scripts/memory-usage.py @@ -26,6 +26,8 @@ def process(tapename, res, regs): regs[type(arg)] = max(regs[type(arg)], arg.i + inst.size) tapes = Program.read_tapes(sys.argv[1]) +n_threads = Program.read_n_threads(sys.argv[1]) +domain_size = Program.read_domain_size(sys.argv[1]) or 8 process(next(tapes), res, regs) @@ -45,9 +47,11 @@ def output(data): pass total = 0 -for x in res, regs, thread_regs: +for x in res, regs: total += sum(x.values()) +thread_total = sum(thread_regs.values()) + print ('Memory:') output(regout(res)) @@ -58,5 +62,9 @@ def output(data): print ('Registers in other threads:') output(regout(thread_regs)) -print ('The program 
requires at the very least %f GB of RAM per party.' % \ - (total * 8e-9)) +min = 1 * domain_size +max = 3 * domain_size + +print ('The program requires at least an estimated %f-%f GB of RAM per party.' + % (min * (total + thread_total) * 1e-9, + max * ((total + (n_threads - 1) * thread_total) * 1e-9))) diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 64d24d300..93352e71a 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -16,6 +16,16 @@ gdb_screen() screen -S :$name -d -m bash -l -c "echo $*; echo $LIBRARY_PATH; gdb $prog -ex \"run $*\"" } +valgrind_screen() +{ + prog=$1 + shift + IFS= + name=${*/-/} + IFS=' ' + screen -S :$name -d -m bash -l -c "echo $*; echo $LIBRARY_PATH; valgrind $prog $*" +} + lldb_screen() { prog=$1 diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index b43c2f3f1..dd70b5f11 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash echo SECURE = -DINSECURE >> CONFIG.mine +touch -r CONFIG CONFIG.mine touch ECDSA/Fake-ECDSA.cpp make -j4 ecdsa Fake-ECDSA.x diff --git a/Tools/DiskVector.cpp b/Tools/DiskVector.cpp new file mode 100644 index 000000000..7bfb3b32b --- /dev/null +++ b/Tools/DiskVector.cpp @@ -0,0 +1,44 @@ +/* + * DiskVectorBase.cpp + * + */ + +#include "DiskVector.h" +#include "Processor/OnlineOptions.h" + +#include + +void sigbus_handler(int) +{ + cerr << "Received SIGBUS. This is most likely due to missing space " + << "for the on-disk memory on " + << OnlineOptions::singleton.disk_memory << "." 
<< endl; + exit(1); +} + +void DiskVectorBase::init(size_t byte_size) +{ + if (file.is_open()) + throw runtime_error("resizing of disk memory not implemented"); + else + { + path = boost::filesystem::unique_path( + (boost::filesystem::path(OnlineOptions::singleton.disk_memory) + / std::string("%%%%-%%%%-%%%%-%%%%")).native()); + + std::ofstream f(path.native()); + f.close(); + } + + if (truncate(path.native().c_str(), byte_size)) + throw std::runtime_error( + "cannot allocate " + std::to_string(byte_size) + " bytes in " + + path.native() + ": " + strerror(errno)); + + file.open(path, boost::iostreams::mapped_file::readwrite, byte_size); + assert(file.size() == byte_size); + + boost::filesystem::remove(path); + + signal(SIGBUS, sigbus_handler); +} diff --git a/Tools/DiskVector.h b/Tools/DiskVector.h new file mode 100644 index 000000000..30d30376f --- /dev/null +++ b/Tools/DiskVector.h @@ -0,0 +1,84 @@ +/* + * DiskVector.h + * + */ + +#ifndef TOOLS_DISKVECTOR_H_ +#define TOOLS_DISKVECTOR_H_ + +#include +#include + +class DiskVectorBase +{ +protected: + boost::iostreams::mapped_file file; + boost::filesystem::path path; + +public: + ~DiskVectorBase() + { + boost::filesystem::remove(path); + } + + void init(size_t byte_size); +}; + +template +class DiskVector : DiskVectorBase +{ + size_t size_; + T* data_; + +public: + DiskVector() : size_(0), data_(0) + { + } + + size_t size() const + { + return size_; + } + + void resize(size_t new_size) + { + auto byte_size = new_size * sizeof(T); + init(byte_size); + size_ = new_size; + data_ = (T*) file.data(); + } + + T* data() + { + return data_; + } + + const T* data() const + { + return data_; + } + + T& operator[](size_t index) + { + return data_[index]; + } + + const T& operator[](size_t index) const + { + return data_[index]; + } + + T& at(size_t index) + { + assert(index <= size_); + return data_[index]; + } + + const T& at(size_t index) const + { + assert(index <= size_); + return data_[index]; + } +}; + +#endif /* 
TOOLS_DISKVECTOR_H_ */ diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index 2d38ec90f..a0e9ab4f5 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -106,3 +106,10 @@ prep_setup_error::prep_setup_error(const string& error, int nplayers, + to_string(nplayers) + fake_opts + "'?") { } + +insufficient_shares::insufficient_shares(int expected, int actual, exception& e) : + runtime_error( + "expected " + to_string(expected) + " shares but only got " + + to_string(actual) + " (" + e.what() + ")") +{ +} diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index b62f01c32..469f544fc 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -296,4 +296,10 @@ class prep_setup_error : public setup_error prep_setup_error(const string& error, int nplayers, const string& fake_opts); }; +class insufficient_shares : public runtime_error +{ +public: + insufficient_shares(int expected, int actual, exception& e); +}; + #endif diff --git a/Tools/parse.h b/Tools/parse.h index c4b973dd7..a9f76f2bb 100644 --- a/Tools/parse.h +++ b/Tools/parse.h @@ -50,4 +50,12 @@ inline void get_vector(int m, vector& start, istream& s) start[i] = be32toh(start[i]); } +inline void get_string(string& res, istream& s) +{ + unsigned size = get_int(s); + char buf[size]; + s.read(buf, size); + res.assign(buf, size); +} + #endif /* TOOLS_PARSE_H_ */ diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index f943a9545..bf33330b4 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -32,6 +32,15 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) "-t", // Flag token. "--threshold" // Flag token. ); + opt.add( + "100000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Size of circuit batches (default: 100000)", // Help description. + "-b", // Flag token. + "--batch-size" // Flag token. 
+ ); auto& online_opts = OnlineOptions::singleton; online_opts = {opt, argc, argv, false}; NetworkOptionsWithNumber network_opts(opt, argc, argv, 2, false); @@ -41,6 +50,7 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) int threshold; bool continuous = not opt.get("-O")->isSet; opt.get("-t")->getInt(threshold); + opt.get("-b")->getInt(online_opts.batch_size); progname = online_opts.progname; GC::ThreadMasterBase* master; diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3f54d385e..e4dc7a8e4 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -14,7 +14,7 @@ steps: - script: | bash -c "sudo apt-get update && sudo apt-get install libsodium-dev libntl-dev python3-gmpy2 python3-networkx" - script: | - make boost libote + make setup - script: echo USE_NTL=1 >> CONFIG.mine - script: diff --git a/deps/libOTe b/deps/libOTe index a10272337..3c1a60029 160000 --- a/deps/libOTe +++ b/deps/libOTe @@ -1 +1 @@ -Subproject commit a10272337b742814673d83e114c03d6904b652e2 +Subproject commit 3c1a60029f097ff794b8a88ce1215cff7eb76628 diff --git a/doc/compilation.rst b/doc/compilation.rst index 9d9cfe875..417bc7922 100644 --- a/doc/compilation.rst +++ b/doc/compilation.rst @@ -3,7 +3,7 @@ Compilation Process The easiest way of using MP-SPDZ is using ``compile.py`` as described below. If you would like to run compilation directly from -Python, see :ref:`Direct Compilation in Python`. +Python, see :ref:`direct-compilation`. After putting your code in ``Program/Source/.[mpc|py]``, run the compiler from the root directory as follows @@ -36,14 +36,17 @@ The following options influence the computation domain: .. cmdoption:: -P --prime= - Specify a concrete prime modulus for computation. This can be used + Use bit decomposition by `Nishide and Ohta + `_ with a concrete + prime modulus for non-linear computation. This can be used together with :option:`-F`, in which case *integer length* has to be at most the prime length minus two. 
The security implications of overflows in the secrets do not go beyond incorrect results. You can use prime order domains without specifying this option. Using this option involves algorithms for non-linear computation which are generally more expensive but allow for integer lengths - that are close to the bit length of the prime. + that are close to the bit length of the prime. See + :ref:`nonlinear` for more details .. cmdoption:: -R --ring= @@ -135,6 +138,8 @@ computation: to the run time. +.. _direct-compilation: + Direct Compilation in Python ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You may prefer to not have an entirely static `.mpc` file to compile, diff --git a/doc/index.rst b/doc/index.rst index 40216b39b..47b481a33 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -8,7 +8,7 @@ more information on multi-party computation. If you're new to MP-SPDZ, consider the following: -1. `Quickstart tutorial `_ +1. `Quickstart tutorial `_ 2. :ref:`Machine learning quickstart ` 3. `Implemented protocols `_ 4. :ref:`troubleshooting` @@ -21,6 +21,8 @@ If you're new to MP-SPDZ, consider the following: readme compilation Compiler + utils + journey instructions low-level ml-quickstart @@ -28,6 +30,7 @@ If you're new to MP-SPDZ, consider the following: networking io client-interface + multinode non-linear preprocessing lowest-level diff --git a/doc/instructions.rst b/doc/instructions.rst index e1be21d63..a050ebb4b 100644 --- a/doc/instructions.rst +++ b/doc/instructions.rst @@ -134,6 +134,8 @@ size. This is to make sure that even when using memory with run-time addresses, the virtual machine is aware of the memory sizes. +.. _instructions: + Instructions ------------ diff --git a/doc/io.rst b/doc/io.rst index 053e93bf9..fd092993d 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -89,14 +89,19 @@ functions are available for :py:class:`~Compiler.types.sfix` and See also :ref:`client ref` below. +.. 
_persistence: + Secret Shares ~~~~~~~~~~~~~ :py:func:`Compiler.types.sint.read_from_file` and :py:func:`Compiler.types.sint.write_to_file` allow reading and writing secret shares to and from files. These instructions use -``Persistence/Transactions-P.data``. The format depends on -the protocol with the following principles. +``Persistence/Transactions-P.data``. These files use the same +header as :ref:`preprocessing files <prep-files>`. The format for the +shares data depends on the protocol and is created by the ``output`` +member function of the relevant :ref:`share type `. It +adheres to the following principles: - One share follows the other without metadata. - If there is a MAC, it comes after the share. diff --git a/doc/journey.rst b/doc/journey.rst new file mode 100644 index 000000000..64ba67be4 --- /dev/null +++ b/doc/journey.rst @@ -0,0 +1,228 @@ +The Journey of a Program +======================== + +In this section, we will demonstrate the life cycle of a high-level +program in MP-SPDZ. + +Consider the following program:: + + print_ln('%s', sint(123).reveal()) + +It entails three steps: creating a constant secret sharing, revealing +it to cleartext, and outputting it. The compilation will not execute +any of this. Instead, it will create a description in a format +specific to MP-SPDZ. For example, +:py:func:`~Compiler.types.sint.reveal` triggers a call to the +constructor of :py:obj:`~Compiler.instructions.asm_open`, which will +add an object thereof to a list of instructions. 
+ +Run the following to retrieve the human-readable representation of the +computation:: + + echo "print_ln('%s', sint(123).reveal())" > Programs/Source/journey.py + ./compile.py -a debug journey + +This will create :file:`debug-journey-0` with the +following content:: + + # journey-0--0 + ldsi s0, 123 # 0 + asm_open 3, True, c0, s0 # 1 + print_reg_plain c0 # 2 + print_char 10 # 3 + use 0, 7, 1 # 4 + # journey-0-memory-usage-1 + ldmc c0, 8191 # 5 + gldmc cg0, 8191 # 6 + ldmint ci0, 8191 # 7 + ldms s0, 8191 # 8 + gldms sg0, 8191 # 9 + active True # 10 + +The first block corresponds mostly to the program whereas the second +block is more generic. More specifically:: + + ldsi s0, 123 # 0 + +:py:class:`~Compiler.instructions.ldsi` loads constant values to +secret registers, in this case 123 to the register :obj:`s0`. + +.. code:: + + asm_open 3, True, c0, s0 # 1 + +:py:class:`~Compiler.instructions.asm_open` reveals values in secret +registers to be stored in cleartext registers, in this case the +content of :obj:`s0` to :obj:`c0`. The :obj:`True` argument triggers a +correctness check in protocols where it is available, and the 3 +indicates the number of arguments to follow as the instruction is +batchable, that is, it can execute any number of openings at once. + +.. code:: + + print_reg_plain c0 # 2 + +:py:class:`~Compiler.instructions.print_reg_plain` outputs constant +values to the console or a file, in this case the register :obj:`c0`. + +.. code:: + + print_char 10 # 3 + +:py:class:`~Compiler.instructions.print_char` outputs a character to +the console or a file, in this case the ASCII code for a new line. + +.. code:: + + use 0, 7, 1 # 4 + +:py:class:`~Compiler.instructions.use` indicates the usage of +preprocessing information or similar. This allows the virtual machine +to account for resources before actually executing the program. This +particular call indicates 1 opening (7) of sint (0). 
You can see the +codes in :py:obj:`data_type` and :py:obj:`field_types` at the beginning +of :download:`Compiler/program.py <../Compiler/program.py>`. + +.. code:: + + ldmc c0, 8191 # 5 + gldmc cg0, 8191 # 6 + ldmint ci0, 8191 # 7 + ldms s0, 8191 # 8 + gldms sg0, 8191 # 9 + +These instructions read memory cells to registers, for example +:py:class:`~Compiler.instructions.ldms`. In this context, the purpose +is to indicate the memory usage. The addresses are all 8191 because +8192 is the default size for user memory given in +:file:`Compiler/config.py`. If you use +:py:class:`~Compiler.types.Array` or similar data-structures, these +numbers will increase accordingly. + +.. code:: + + active True # 10 + +:py:class:`~Compiler.instructions.active` indicates whether the +program is compatible with active security. + +The compilation above also creates +:file:`Programs/Bytecode/journey-0.bc`, the hexdump output of which +looks as follows:: + + 00000000 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 7b |...............{| + 00000010 00 00 00 00 00 00 00 a5 00 00 00 03 00 00 00 01 |................| + 00000020 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 b3 |................| + 00000030 00 00 00 00 00 00 00 00 00 00 00 b4 00 00 00 0a |................| + 00000040 00 00 00 00 00 00 00 17 00 00 00 00 00 00 00 07 |................| + 00000050 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 03 |................| + 00000060 00 00 00 00 00 00 00 00 00 00 1f ff 00 00 00 00 |................| + 00000070 00 00 01 03 00 00 00 00 00 00 00 00 00 00 1f ff |................| + 00000080 00 00 00 00 00 00 00 ca 00 00 00 00 00 00 00 00 |................| + 00000090 00 00 1f ff 00 00 00 00 00 00 00 04 00 00 00 00 |................| + 000000a0 00 00 00 00 00 00 1f ff 00 00 00 00 00 00 01 04 |................| + 000000b0 00 00 00 00 00 00 00 00 00 00 1f ff 00 00 00 00 |................| + 000000c0 00 00 00 e9 00 00 00 01 |........| + 000000c8 + +It consists of the instruction codes and 
the 
arguments in big-endian +order. For example, 0x2 is the code for :py:obj:`ldsi`, 0xa5 is the +code for :py:obj:`asm_open`, 0xb3 is the code for +:py:obj:`print_reg_plain`, etc. You can also spot repeated occurrences +of ``1f ff``, which is the hexadecimal representation of 8191. + +Finally, the compilation creates +:file:`Programs/Schedules/journey.sch`, which is a text file:: + + 1 + 1 + journey-0:11 + 1 0 + 0 + ./compile.py journey + lgp:0 + opts: + sec:40 + +The first two lines indicate the number of threads and bytecode files, +followed by the names of bytecode files (and the number of +instructions in each one). The fourth and fifth lines are legacy, and +the sixth indicates the compilation command line. The remaining lines +indicate further options used during compilation. + + +Execution +--------- + +.. default-domain:: cpp + +We will now walk through what happens when executing the program above +with Rep3 modulo :math:`2^{64}`. The main function in +:download:`Machines/replicated-ring-party.cpp +<../Machines/replicated-ring-party.cpp>` indirectly calls +:func:`Machine::run` in :download:`Processor/Machine.hpp +<../Processor/Machine.hpp>` with :class:`sint` being +``Rep3Share2<64>``. Then, the following happens: + +1. :file:`Programs/Schedules/journey.sch` is parsed in :func:`load_schedule`. +2. :file:`Programs/Bytecode/journey-0.bc` is parsed in + :func:`Machine::load_program` where + :func:`Program::parse` is called. This creates an internal representation of the + code in :var:`Program::p` where an :class:`Instruction` object + describes every instruction. +3. :func:`Machine::prepare` creates a computation thread + using :func:`pthread_create`, which runs :func:`thread_info::Main_Func` in :download:`Processor/Online-Thread.hpp + <../Processor/Online-Thread.hpp>`. +4. :func:`Machine::run` calls :func:`Machine::run_tape`, which signals the thread which code to run. +5. The computation thread waits for a signal in + :func:`thread_info::Sub_Main_Func`. 
Once received, it + calls :func:`Program::execute` in + :download:`Processor/Instruction.hpp <../Processor/Instruction.hpp>`. +6. :func:`Program::execute` runs the main loop over the + instructions. There is a switch statement acting on the instruction + codes. +7. ``LDSI`` is defined in ``ARITHMETIC_INSTRUCTIONS`` in + :download:`Processor/instructions.h + <../Processor/instructions.h>`. It calls :func:`sint::constant`, + which is defined in :download:`Protocols/Rep3Share.h + <../Protocols/Rep3Share.h>` for ``Rep3Share2<64>``. This in + turn calls :func:`Replicated::assign` in + :download:`Protocols/Replicated.h <../Protocols/Replicated.h>`, + which creates a constant replicated secret sharing of 123, that is + (123, 0) for party 0, (0, 123) for party 1, and (0, 0) for party 2. +8. ``OPEN`` is defined in another switch statement in + :func:`Instruction::execute` in + :download:`Processor/Instruction.hpp + <../Processor/Instruction.hpp>`, where :func:`SubProcessor::POpen` + in :download:`Processor/Processor.hpp <../Processor/Processor.hpp>` + is called. This in turn uses the four-step interface of + :class:`MAC_Check_Base` with an instance of + :class:`ReplicatedMC`. The communication happens in + :func:`ReplicatedMC::exchange`, and the reconstruction (summation) + happens in :func:`ReplicatedMC::finalize`, both in + :download:`Protocols/ReplicatedMC.hpp + <../Protocols/ReplicatedMC.hpp>`. The remaining functions mainly + handle copying data and serialization. +9. ``PRINTREGPLAIN`` is also defined in the second switch statement, + where :func:`Instruction::print` in + :download:`Processor/Instruction.hpp + <../Processor/Instruction.hpp>` is called. This function uses + :class:`SwitchableOutput`, which is used to output to console, to + file, or not at all depending on the settings. +10. 
``PRINTCHR`` is defined in ``REGINT_INSTRUCTIONS`` in + :download:`Processor/instructions.h + <../Processor/instructions.h>`, which means that it's called via a + switch statement in :func:`Instruction::execute_regint` in + :download:`Processor/Instruction.cpp + <../Processor/Instruction.cpp>`. It also uses + :class:`SwitchableOutput`. +11. The remaining instructions are executed similarly but do not have + a relevant effect. +12. When :func:`Program::execute` is done, control returns to + :func:`thread_info::Sub_Main_Func`, which signals + completion to the main thread. +13. After receiving the signal, :func:`Machine::run` + completes and outputs the various statistics and exits. + diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst index d873256bd..5dc2fc865 100644 --- a/doc/machine-learning.rst +++ b/doc/machine-learning.rst @@ -443,7 +443,7 @@ and used in MP-SPDZ:: This outputs the accuracy of the network. You can use :py:func:`~Compiler.ml.Optimizer.eval` instead of :py:func:`~Compiler.ml.Optimizer.reveal_correctness` to retrieve -probability distributions or top guessess (the latter with ``top=True``) +probability distributions or top guesses (the latter with ``top=True``) for any sample data. diff --git a/doc/multinode.rst b/doc/multinode.rst new file mode 100644 index 000000000..69cb47be0 --- /dev/null +++ b/doc/multinode.rst @@ -0,0 +1,65 @@ +Multinode Computation Example +============================= + +Multinode computation refers to the possibility of distributing +every party across several nodes. MP-SPDZ uses the client interface +for communication between nodes. This means that you have to run +``Scripts/setup-clients.sh`` and distribute the certificates to run it +across several machines. + +In the following, we will explain the example in +:download:`../Programs/Source/multinode_example_main.py` and +:download:`../Programs/Source/multinode_example_worker.py`. 
+ +First, the one main node per party listens for and accepts connections +from the worker nodes in the same logical party:: + + listen_for_clients(15000) + + ready = regint.Array(n_nodes_per_party) + + @for_range(n_nodes) + def _(i): + ready[accept_client_connection(15000)] = 1 + + runtime_error_if(sum(ready) != n_nodes_party, 'connection problems') + +Maintaining :py:obj:`ready` helps spot errors but isn't strictly +necessary. Meanwhile, the workers connect to the main node:: + + main = init_client_connection(host, 15000, worker_id) + +Once the connection is established, the main node distributes the data +among the workers:: + + @for_range(n_nodes_per_party) + def _(i): + data.get_vector(base=i * n_ops_per_node, + size=n_ops_per_node).write_fully_to_socket(i) + +This sends a different chunk of :py:obj:`data` to every node. The workers +then receive it and execute the computation (squaring every number in +the example), and send the result back:: + + @for_range_opt_multithread(n_threads, n_ops) + def _(i): + data[i] = data[i] ** 2 + + data.write_to_socket(main) + +Finally, the main node receives the result:: + + @for_range(n_nodes_per_party) + def _(i): + data.assign_vector(sint.read_from_socket(i, size=n_ops_per_node), + base=i * n_ops_per_node) + +You can execute the example with three parties, four worker nodes per +party, five threads per worker node, and 1000 operations per thread as +follows:: + + for i in $(seq 0 3); do + Scripts/compile-run.py ring multinode_example_worker 5 1000 $i localhost & true + done + + Scripts/compile-run.py ring multinode_example_main 4 5 1000 diff --git a/doc/non-linear.rst b/doc/non-linear.rst index 4687cc637..ec2a53c23 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -25,6 +25,9 @@ Unknown prime modulus parameter. It has the downside that there is implicit enforcement of the cleartext range. + If you want to use this approach with a given prime, do *not* + specify the prime during compilation but during execution. 
+ Known prime modulus `Damgård et al. `_ have proposed non-linear computation that involves an exact prime @@ -41,6 +44,8 @@ Known prime modulus :math:`k`-bit number is indistinguishable from a random number modulo :math:`p` if the latter is close enough to :math:`2^k`. + This approach is used if you specify a prime during compilation. + Power-of-two modulus In the context of non-linear computation, there are two important differences to prime modulus setting: diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 21500c455..74a686eb6 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -29,6 +29,8 @@ using ``-b``, others mandate a batch size, which can be as large as a million. +.. _prep-files: + Separate preprocessing ====================== @@ -102,11 +104,22 @@ Modulo a prime with :math:`R` being the smallest power of :math:`2^{64}` larger than the prime. For example, :math:`R = 2^{128}` for a 128-bit prime. Furthermore, the values are stored in the smallest number of 8-byte - blocks necessary, all in little-endian order. + blocks necessary, all in little-endian order. As an example, + consider the default 128-bit prime + :math:`p = 170141183460469231731687303715885907969`. The Montgomery + representation of :math:`x` is :math:`xR \bmod p`. For :math:`x = + 1`, this is 170141183460469231731687303715882303487 or + 0x7fffffffffffffffffffffffffe47fff in hexadecimal. Using + little-endian order, ``hexdump -C`` would output the following:: + + ff 7f e4 ff ff ff ff ff ff ff ff ff ff ff ff 7f Modulo a power of two: Values are stored in the smallest number of 8-byte blocks necessary, - all in little-endian order. 
+ all in little-endian order, so 1 with a modulus of :math:`2^{64}` + would result in the following ``hexdump -C`` output:: + + 01 00 00 00 00 00 00 00 :math:`GF(2^n)` Values are stored in blocks according to the storage size above, diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 508936a96..7517312c7 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -15,6 +15,10 @@ memory usage for some malicious protocols with ``-B 5``. Furthermore, every computation thread requires separate resources, so consider reducing the number of threads with :py:func:`~Compiler.library.for_range_multithreads` and similar. +Lastly, you can use ``--disk-memory `` to use disk space instead +of RAM for large programs. +Use ``Scripts/memory-usage.py `` to get an estimate +of the memory usage of a specific program. List indices must be integers or slices @@ -83,6 +87,12 @@ If the condition is secret, for example, :py:obj:`x` is an branching would reveal the secret. For the same reason, :py:func:`~Compiler.library.print_ln_if` doesn't work on secret values. +Use ``bit_and`` etc. for more elaborate conditions:: + + @if_(a.bit_and(b.bit_or(c))) + def _(): + ... + Incorrect results when using :py:class:`~Compiler.types.sfix` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -97,7 +107,7 @@ to change the precision. Variable results when using :py:class:`~Compiler.types.sfix` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This is caused the usage of probablistic rounding, which is used to +This is caused by the usage of probabilistic rounding, which is used to restore the representation after a multiplication. See `Catrina and Saxena `_ for details. You can switch to deterministic rounding by calling ``sfix.round_nearest = True``. @@ -116,6 +126,16 @@ protection (:py:func:`~Compiler.program.Program.protect_memory`) around specific memory accesses. 
+High number of rounds or slow WAN execution +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can increase the optimization budget using ``--budget`` during +compilation. The budget controls the trade-off between compilation +speed/memory usage and communication rounds during execution. The +default is 1000, but 100,000 might give better results while still +keeping compilation manageable. + + Odd timings ~~~~~~~~~~~ diff --git a/doc/utils.rst b/doc/utils.rst new file mode 100644 index 000000000..019bf2889 --- /dev/null +++ b/doc/utils.rst @@ -0,0 +1,23 @@ +Bytecode Utilities +================== + + +Memory usage +------------ + +``Scripts/memory-usage.py `` gives you an estimate +of the minimum RAM usage per party. The range is relatively large due +to the fact that the bytecode is independent of the secret sharing. + + +Human-readable bytecode/circuit representation +---------------------------------------------- + +``Scripts/decompile.py `` produces a human-readable +version of the bytecode in ``Programs/Bytecode``. The filename format +is ``Programs/Bytecode/-.asm``. For +example, after compiling and decompiling the tutorial, you will find +``Programs/Bytecode/tutorial-0.asm``. You can find the full list of +tape names in the third line of ``Programs/Schedules/tutorial.sch``. +See :ref:`this section <instructions>` for an explanation of +instruction names.