Skip to content

Support multibyte characters #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 27, 2019
68 changes: 45 additions & 23 deletions kaldiio/matio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from kaldiio.utils import MultiFileDescriptor
from kaldiio.utils import open_like_kaldi
from kaldiio.utils import open_or_fd
from kaldiio.utils import py2_default_encoding
from kaldiio.utils import seekable
from kaldiio.wavio import read_wav
from kaldiio.wavio import read_wav_scipy
Expand Down Expand Up @@ -55,7 +56,7 @@ def load_scp(fname, endian='<', separator=None, as_bytes=False,
raise ValueError(
'Invalid line is found:\n> {}'.format(line))
token, arkname = seps
loader[token] = arkname.strip()
loader[token] = arkname.rstrip()
return loader
else:
return SegmentsExtractor(fname, separator=separator,
Expand Down Expand Up @@ -86,7 +87,7 @@ def load_scp_sequential(fname, endian='<', separator=None, as_bytes=False,
raise ValueError(
'Invalid line is found:\n> {}'.format(line))
token, arkname = seps
arkname = arkname.strip()
arkname = arkname.rstrip()

ark, offset, slices = _parse_arkpath(arkname)

Expand Down Expand Up @@ -140,9 +141,9 @@ def __init__(self, fname, segments=None, separator=None):

self.segments = segments
self._segments_dict = {}
with open(self.segments) as f:
with open_or_fd(self.segments, 'r') as f:
for l in f:
sps = l.strip().split(separator)
sps = l.rstrip().split(separator)
if len(sps) != 4:
raise RuntimeError('Format is invalid: {}'.format(l))
uttid, recodeid, st, et = sps
Expand Down Expand Up @@ -223,7 +224,7 @@ def _parse_arkpath(ark_name):
>>> _parse_arkpath('cat "fo:o.ark" |')
'cat "fo:o.ark" |', None, None
"""
if ark_name.strip()[-1] == '|' or ark_name.strip()[0] == '|':
if ark_name.rstrip()[-1] == '|' or ark_name.rstrip()[0] == '|':
# Something like: "| cat foo" or "cat bar|" shouldn't be parsed
return ark_name, None, None

Expand Down Expand Up @@ -299,17 +300,19 @@ def read_token(fd):
fd (file):
"""
token = []
    # Keep looping until a space (' ') or the end of the file (an empty read) is found
while True:
char = fd.read(1)
if isinstance(char, binary_type):
char = char.decode()
if char == ' ' or char == '':
c = fd.read(1)
if c == b' ' or c == b'':
break
else:
token.append(char)
token.append(c)
if len(token) == 0: # End of file
return None
return ''.join(token)
if PY3:
decoded = b''.join(token).decode()
else:
decoded = b''.join(token).decode(py2_default_encoding)
return decoded


def read_kaldi(fd, endian='<', return_size=False, use_scipy_wav=False):
Expand Down Expand Up @@ -476,11 +479,14 @@ def read_ascii_mat(fd, return_size=False):

# Find '[' char
while True:
b = fd.read(1)
try:
char = fd.read(1).decode()
except UnicodeDecodeError as e:
raise UnicodeDecodeError(
str(e) + '\nFile format is wrong?')
if PY3:
char = b.decode()
else:
char = b.decode(py2_default_encoding)
except UnicodeDecodeError:
raise ValueError('File format is wrong?')
size += 1
if char == ' ' or char == os.linesep:
continue
Expand All @@ -495,11 +501,17 @@ def read_ascii_mat(fd, return_size=False):
# Read data
ndmin = 1
while True:
char = fd.read(1).decode()
if PY3:
char = fd.read(1).decode()
else:
char = fd.read(1).decode(py2_default_encoding)
size += 1
if hasparent:
if char == ']':
char = fd.read(1).decode()
if PY3:
char = fd.read(1).decode()
else:
char = fd.read(1).decode(py2_default_encoding)
size += 1
assert char == os.linesep or char == ''
break
Expand Down Expand Up @@ -582,8 +594,12 @@ def save_ark(ark, array_dict, scp=None, append=False, text=False,
offset = 0
size = 0
for key in array_dict:
fd.write((key + ' ').encode())
size += len(key) + 1
if PY3:
encode_key = (key + ' ').encode()
else:
encode_key = (key + ' ').encode(py2_default_encoding)
fd.write(encode_key)
size += len(encode_key)
pos_list.append(size)
if as_bytes:
byte = bytes(array_dict[key])
Expand All @@ -606,7 +622,7 @@ def save_ark(ark, array_dict, scp=None, append=False, text=False,
name = ark if isinstance(ark, string_types) else ark.name
with open_or_fd(scp, mode) as fd:
for key, position in zip(array_dict, pos_list):
fd.write(key + ' ' + name + ':' +
fd.write(key + u' ' + name + ':' +
str(position + offset) + os.linesep)


Expand Down Expand Up @@ -730,7 +746,10 @@ def write_array_ascii(fd, array, digit='.12g'):
size += 3
for i in row:
string = format(i, digit)
fd.write(string.encode())
if PY3:
fd.write(string.encode())
else:
fd.write(string.encode(py2_default_encoding))
fd.write(b' ')
size += len(string) + 1
fd.write(b']\n')
Expand All @@ -740,7 +759,10 @@ def write_array_ascii(fd, array, digit='.12g'):
size += 1
for i in array:
string = format(i, digit)
fd.write(string.encode())
if PY3:
fd.write(string.encode())
else:
fd.write(string.encode(py2_default_encoding))
fd.write(b' ')
size += len(string) + 1
fd.write(b']\n')
Expand Down
12 changes: 10 additions & 2 deletions kaldiio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
else:
from collections import MutableMapping

py2_default_encoding = 'utf-8'


if PY3:
def my_popen(cmd, mode='r', buffering=-1):
Expand Down Expand Up @@ -147,14 +149,20 @@ def open_like_kaldi(name, mode='r'):
else:
return _stdstream_wrap(sys.stdout)
else:
return open(name, mode)
return io.open(name, mode)


@contextmanager
def open_or_fd(fname, mode):
# If fname is a file name
if isinstance(fname, string_types):
f = open(fname, mode)
if PY3:
f = open(fname, mode)
else:
if 'b' not in mode:
f = io.open(fname, mode, encoding=py2_default_encoding)
else:
f = io.open(fname, mode)
# If fname is a file descriptor
else:
if PY3 and 'b' in mode and isinstance(fname, TextIOBase):
Expand Down
40 changes: 21 additions & 19 deletions tests/test_mat_ark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# coding: utf-8
import glob
import io
import os

import numpy as np
Expand Down Expand Up @@ -30,7 +32,7 @@ def test_write_read(tmpdir, shape1, shape2, endian, dtype):

a = np.random.rand(*shape1).astype(dtype)
b = np.random.rand(*shape2).astype(dtype)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath, endian=endian)

Expand All @@ -39,7 +41,7 @@ def test_write_read(tmpdir, shape1, shape2, endian, dtype):
d5 = {k: v
for k, v in kaldiio.load_scp(path.join('b.scp').strpath,
endian=endian).items()}
with open(path.join('a.ark').strpath, 'rb') as fd:
with io.open(path.join('a.ark').strpath, 'rb') as fd:
d6 = {k: v for k, v in
kaldiio.load_ark(fd, endian=endian)}
_compare(d2, origin)
Expand All @@ -54,15 +56,15 @@ def test_write_read_multiark(tmpdir, endian, dtype):

a = np.random.rand(1000, 120).astype(dtype)
b = np.random.rand(10, 120).astype(dtype)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}

kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath, endian=endian)

c = np.random.rand(1000, 120).astype(dtype)
d = np.random.rand(10, 120).astype(dtype)
origin.update({'c': c, 'd': d})
with open(path.join('b.scp').strpath, 'a') as f:
origin.update({u'c': c, u'd': d})
with io.open(path.join('b.scp').strpath, 'a', encoding='utf-8') as f:
kaldiio.save_ark(path.join('b.ark').strpath, origin,
scp=f, endian=endian)

Expand All @@ -78,7 +80,7 @@ def test_write_read_sequential(tmpdir, endian):

a = np.random.rand(1000, 120).astype(np.float32)
b = np.random.rand(10, 120).astype(np.float32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath, endian=endian)

Expand All @@ -94,15 +96,15 @@ def test_write_read_multiark_sequential(tmpdir, endian):

a = np.random.rand(1000, 120).astype(np.float32)
b = np.random.rand(10, 120).astype(np.float32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}

kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath, endian=endian)

c = np.random.rand(1000, 120).astype(np.float32)
d = np.random.rand(10, 120).astype(np.float32)
origin.update({'c': c, 'd': d})
with open(path.join('b.scp').strpath, 'a') as f:
origin.update({u'c': c, u'd': d})
with io.open(path.join('b.scp').strpath, 'a', encoding='utf-8') as f:
kaldiio.save_ark(path.join('b.ark').strpath, origin,
scp=f, endian=endian)

Expand All @@ -116,7 +118,7 @@ def test_write_read_ascii(tmpdir):
path = tmpdir.mkdir('test')
a = np.random.rand(10, 10).astype(np.float32)
b = np.random.rand(5, 35).astype(np.float32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('a.scp').strpath, text=True)
d2 = {k: v for k, v in kaldiio.load_ark(path.join('a.ark').strpath)}
Expand All @@ -132,7 +134,7 @@ def test_write_read_int32_vector(tmpdir, endian):

a = np.random.randint(1, 128, 10, dtype=np.int32)
b = np.random.randint(1, 128, 10, dtype=np.int32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath,
endian=endian)
Expand All @@ -142,7 +144,7 @@ def test_write_read_int32_vector(tmpdir, endian):
d5 = {k: v
for k, v in kaldiio.load_scp(path.join('b.scp').strpath,
endian=endian).items()}
with open(path.join('a.ark').strpath, 'rb') as fd:
with io.open(path.join('a.ark').strpath, 'rb') as fd:
d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)}
_compare(d2, origin)
_compare(d5, origin)
Expand All @@ -154,15 +156,15 @@ def test_write_read_int32_vector_ascii(tmpdir):

a = np.random.randint(1, 128, 10, dtype=np.int32)
b = np.random.randint(1, 128, 10, dtype=np.int32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath,
text=True)

d2 = {k: v for k, v in kaldiio.load_ark(path.join('a.ark').strpath)}
d5 = {k: v
for k, v in kaldiio.load_scp(path.join('b.scp').strpath).items()}
with open(path.join('a.ark').strpath, 'r') as fd:
with io.open(path.join('a.ark').strpath, 'rb') as fd:
d6 = {k: v for k, v in kaldiio.load_ark(fd)}
_compare_allclose(d2, origin)
_compare_allclose(d5, origin)
Expand Down Expand Up @@ -190,7 +192,7 @@ def test_write_read_compress(tmpdir, compression_method, endian):

a = np.random.rand(1000, 120).astype(np.float32)
b = np.random.rand(10, 120).astype(np.float32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath,
compression_method=compression_method,
Expand All @@ -201,7 +203,7 @@ def test_write_read_compress(tmpdir, compression_method, endian):
d5 = {k: v
for k, v in kaldiio.load_scp(path.join('b.scp').strpath,
endian=endian).items()}
with open(path.join('a.ark').strpath, 'rb') as fd:
with io.open(path.join('a.ark').strpath, 'rb') as fd:
d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)}
_compare_allclose(d2, origin, atol=1e-1)
_compare_allclose(d5, origin, atol=1e-1)
Expand All @@ -213,13 +215,13 @@ def test_append_mode(tmpdir):

a = np.random.rand(1000, 120).astype(np.float32)
b = np.random.rand(10, 120).astype(np.float32)
origin = {'a': a, 'b': b}
origin = {u'Ï,é,à': a, u'あいうえお': b}
kaldiio.save_ark(path.join('a.ark').strpath, origin,
scp=path.join('b.scp').strpath)

kaldiio.save_ark(path.join('a2.ark').strpath, {'a': a},
kaldiio.save_ark(path.join('a2.ark').strpath, {u'Ï,é,à': a},
scp=path.join('b2.scp').strpath, append=True)
kaldiio.save_ark(path.join('a2.ark').strpath, {'b': b},
kaldiio.save_ark(path.join('a2.ark').strpath, {u'あいうえお': b},
scp=path.join('b2.scp').strpath, append=True)
d1 = {k: v for k, v in kaldiio.load_ark(path.join('a.ark').strpath)}
d2 = {k: v
Expand Down