diff --git a/lib/money-tree.rb b/lib/money-tree.rb index fa21a70..60aca03 100644 --- a/lib/money-tree.rb +++ b/lib/money-tree.rb @@ -8,5 +8,17 @@ require "money-tree/node" module MoneyTree - + DEFAULT_CURVE_NAME = 'secp256k1' + DEFAULT_CURVE_ID = 714 + @curve_name = DEFAULT_CURVE_NAME + @curve_id = DEFAULT_CURVE_ID + + def self.setCurve(name, id) + @curve_name = name || DEFAULT_CURVE_NAME + @curve_id = id || DEFAULT_CURVE_ID + end + + def self.getCurve() + { name: @curve_name, id: @curve_id } + end end diff --git a/lib/money-tree/key.rb b/lib/money-tree/key.rb index b8a122f..924dc9e 100644 --- a/lib/money-tree/key.rb +++ b/lib/money-tree/key.rb @@ -38,7 +38,7 @@ class PrivateKey < Key def initialize(opts = {}) @options = opts - @ec_key = PKey::EC.new GROUP_NAME + @ec_key = PKey::EC.new MoneyTree.getCurve()[:name] || GROUP_NAME if @options[:key] @raw_key = @options[:key] @key = parse_raw_key @@ -176,7 +176,7 @@ def initialize(p_key, opts = {}) @key = @raw_key = to_hex else @raw_key = p_key - @group = PKey::EC::Group.new GROUP_NAME + @group = PKey::EC::Group.new MoneyTree.getCurve()[:name] || GROUP_NAME @key = parse_raw_key end diff --git a/lib/money-tree/version.rb b/lib/money-tree/version.rb index 612eaa6..17b3dda 100644 --- a/lib/money-tree/version.rb +++ b/lib/money-tree/version.rb @@ -1,3 +1,3 @@ module MoneyTree - VERSION = "0.10.0" + VERSION = "0.10.1" end diff --git a/lib/openssl_extensions.rb b/lib/openssl_extensions.rb index 44b7f5e..3adf250 100644 --- a/lib/openssl_extensions.rb +++ b/lib/openssl_extensions.rb @@ -20,12 +20,12 @@ module OpenSSLExtensions attach_function :EC_POINT_point2hex, [:pointer, :pointer, :int, :pointer], :string attach_function :EC_POINT_hex2point, [:pointer, :string, :pointer, :pointer], :pointer attach_function :EC_POINT_new, [:pointer], :pointer - + def self.add(point_0, point_1) validate_points(point_0, point_1) - eckey = EC_KEY_new_by_curve_name(NID_secp256k1) + eckey = EC_KEY_new_by_curve_name(MoneyTree.getCurve()[:id] || NID_secp256k1) group = EC_KEY_get0_group(eckey) - + point_0_hex = point_0.to_bn.to_s(16) point_0_pt = EC_POINT_hex2point(group, point_0_hex, nil, nil) point_1_hex = point_1.to_bn.to_s(16) @@ -52,9 +52,9 @@ def self.add(point_0, point_1) def self.validate_points(*points) points.each do |point| if !point.is_a?(OpenSSL::PKey::EC::Point) - raise ArgumentError, "point must be an OpenSSL::PKey::EC::Point object" + raise ArgumentError, "point must be an OpenSSL::PKey::EC::Point object" elsif point.infinity? - raise ArgumentError, "point must not be infinity" + raise ArgumentError, "point must not be infinity" end end end @@ -64,10 +64,10 @@ def self.validate_points(*points) class OpenSSL::PKey::EC::Point include MoneyTree::OpenSSLExtensions - + def add(point) sum_point_hex = MoneyTree::OpenSSLExtensions.add(self, point) self.class.new group, OpenSSL::BN.new(sum_point_hex, 16) end - + end