Skip to content

Commit f849918

Browse files
author
Wayne Zhao
committed
Renamed files, added guinea pig and mouse
1 parent bf28801 commit f849918

6 files changed

+842
-72
lines changed

Cochlea.py

+54-46
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ class Cochlea(FilterBank):
1212
'''
1313
Model of Cochlea (a subclass of FilterBank)
1414
'''
15-
def __init__(self, types=None, species=None, CF0=20, length=20, l_factor=3.8, xs=None, cfs=None, rho=1.000, Ap=None, bp=None, Bu=None, gain_const=None, peak_magndb=None, Bpeak=None, fpeak=None, phiaccum=None, Nbeta=None, Nf=None, Qerb=None, ERBbeta=None, ERBf=None, Qn=None, Qn2=None, BWndBbeta=None, BWndBf=None, BWn2dBbeta=None, BWn2dBf=None, Sbeta=None, Sf=None, n=10, n2=3, betas=None, freqs=None):
15+
def __init__(self, species=None, type=None, CF0=20, l_factor=3.8, length=20, xs=None, cfs=None, rho=1.000, Ap=None, bp=None, Bu=None, gain_const=None, peak_magndb=None, Bpeak=None, fpeak=None, phiaccum=None, Nbeta=None, Nf=None, Qerb=None, ERBbeta=None, ERBf=None, Qn=None, Qn2=None, BWndBbeta=None, BWndBf=None, BWn2dBbeta=None, BWn2dBf=None, Sbeta=None, Sf=None, n=10, n2=3, betas=None, freqs=None):
1616
'''
17-
Initializes Cochlea. Most arguments are the same as for `FilterBank` object
17+
Initializes Cochlea. Most arguments are the same as for `FilterBank` object.
1818
1919
Attributes:
20+
species: fast way to initialize Cochlea for certain species
21+
num_filters: number of filters to model cochlea with. Default is 4.
2022
CF0: characteristic frequency (kHz) of base of cochlea
2123
length: length of cochlea (mm)
2224
l_factor: constant factor for cochlear model (mm)
@@ -28,30 +30,40 @@ def __init__(self, types=None, species=None, CF0=20, length=20, l_factor=3.8, xs
2830
'''
2931
if species is not None:
3032
CF0, l_factor, length = self._given_species(species)
31-
3233
self.cochlea_length = length
34+
self.cf = (lambda x: CF0*np.exp(-x/l_factor))
3335

34-
args = {'Ap':Ap, 'bp':bp, 'Bu':Bu, 'gain_const':gain_const, 'peak_magndb':peak_magndb,
35-
'Bpeak':Bpeak, 'fpeak':fpeak, 'phiaccum':phiaccum, 'Nbeta':Nbeta, 'Nf':Nf, 'Qerb':Qerb, 'ERBbeta':ERBbeta, 'ERBf':ERBf, 'Qn':Qn, 'Qn2':Qn2, 'BWndBbeta':BWndBbeta, 'BWndBbeta':BWndBf, 'BWn2dBbeta':BWn2dBbeta, 'BWn2dBf':BWn2dBf, 'Sbeta':Sbeta, 'Sf':Sf, 'n':n, 'n2':n2, 'betas':betas, 'freqs':freqs}
36-
num_filters = 1
37-
for v in args.values():
38-
if np.ndim(v) >= 1:
39-
num_filters = max(num_filters, len(v))
36+
type = 'P' if type is None else type
4037

41-
self.cf = (lambda x: CF0*np.exp(-x/l_factor))
42-
if cfs is None:
43-
if xs is None:
44-
xs = np.linspace(0, length, num_filters)
45-
cfs = [self.cf(x) for x in xs]
38+
if species is not None:
39+
xs = np.linspace(0, length, 4) # let user set?
40+
cfs = self.cf(xs)
41+
Ap = 0.3768 * np.exp(-0.1366 * cfs) # seems improbable
42+
bp = [1., 1., 1., 1.]
43+
Bu = 3.714 * np.exp(0.03123 * cfs) # same here
44+
args = {'Ap':Ap, 'bp':bp, 'Bu':Bu, 'cf':cfs}
4645
else:
47-
if xs is not None:
48-
raise Exception('Please provide either only a list of all locations along cochlea or a list of all characteristic frequencies')
49-
xs = [-np.log(c/CF0)*l_factor for c in cfs]
50-
self.xs = xs
51-
types = 'P' if types is None else types
46+
args = {'Ap':Ap, 'bp':bp, 'Bu':Bu, 'gain_const':gain_const, 'peak_magndb':peak_magndb,
47+
'Bpeak':Bpeak, 'fpeak':fpeak, 'phiaccum':phiaccum, 'Nbeta':Nbeta, 'Nf':Nf, 'Qerb':Qerb, 'ERBbeta':ERBbeta, 'ERBf':ERBf, 'Qn':Qn, 'Qn2':Qn2, 'BWndBbeta':BWndBbeta, 'BWndBbeta':BWndBf, 'BWn2dBbeta':BWn2dBbeta, 'BWn2dBf':BWn2dBf, 'Sbeta':Sbeta, 'Sf':Sf, 'n':n, 'n2':n2}
48+
num_filters = 1
49+
for v in args.values():
50+
if np.ndim(v) >= 1:
51+
num_filters = max(num_filters, len(v))
5252

53-
args['cf'] = cfs
54-
super().__init__(topology='parallel', type=types, **args)
53+
if cfs is None:
54+
if xs is None:
55+
xs = np.linspace(0, length, num_filters)
56+
cfs = [self.cf(x) for x in xs]
57+
else:
58+
if xs is not None:
59+
raise Exception('Please provide either only a list of all locations along cochlea or a list of all characteristic frequencies')
60+
xs = [-np.log(cf/CF0)*l_factor for cf in cfs]
61+
62+
self.xs = xs
63+
args['cf'] = cfs
64+
args['betas'] = [betas for _ in range(num_filters)]
65+
args['freqs'] = [freqs for _ in range(num_filters)]
66+
super().__init__(topology='series', type=type, **args)
5567

5668
apexmost_filter = self.filters[-1]
5769

@@ -64,8 +76,6 @@ def __init__(self, types=None, species=None, CF0=20, length=20, l_factor=3.8, xs
6476
self.bp_fun = (lambda x: np.exp(np.interp(x, self.xs, np.log(np.array(bp)))))
6577
self.Bu_fun = (lambda x: np.exp(np.interp(x, self.xs, np.log(np.array(Bu)))))
6678

67-
# beta = w/2pi/CF(x)
68-
6979
p = 1j*bp_apex - Ap_apex
7080
# k and Z both normalized to not depend on l
7181
self.wavenumber = (lambda beta: (beta/l_factor) * 2 * Bu_apex * (1j*beta + Ap_apex) / ((1j*beta - p)*(1j*beta - p.conjugate())))
@@ -76,18 +86,16 @@ def __init__(self, types=None, species=None, CF0=20, length=20, l_factor=3.8, xs
7686
self.Z = self.impedance
7787

7888
@classmethod
79-
def five_param(cls, types=None, aAp=None, bAp=None, bp=None, aBu=None, bBu=None, gain_const=None, peak_magndb=None, CF0=20, length=20, xs=None, rho=1.000):
89+
def five_param(cls, type=None, aAp=None, bAp=None, bp=None, aBu=None, bBu=None, gain_const=None, peak_magndb=None, CF0=20, l_factor=3.8, length=20, xs=None, rho=1.000, betas=None, freqs=None):
8090
'''
8191
Five parameter parameterization of Cochlea from (Alkhairy 2019)
8292
'''
8393
if xs is None:
8494
xs = np.linspace(0, length, 4)
85-
cf = (lambda x: CF0*np.exp(-x/3.8))
86-
# print('cfx', cf(0))
87-
# print(bAp, np.exp(bAp*cf(0)))
95+
cf = (lambda x: CF0*np.exp(-x/l_factor))
8896
Ap_func = (lambda x: aAp*np.exp(bAp*cf(x)))
8997
Bu_func = (lambda x: aBu*np.exp(bBu*cf(x)))
90-
cochlea = cls(types=types, Ap=[Ap_func(x) for x in xs], bp=bp, Bu=[Bu_func(x) for x in xs], gain_const=gain_const, peak_magndb=peak_magndb, CF0=CF0, length=length, xs=xs, rho=rho, species=None)
98+
cochlea = cls(type=type, Ap=[Ap_func(x) for x in xs], bp=bp, Bu=[Bu_func(x) for x in xs], gain_const=gain_const, peak_magndb=peak_magndb, CF0=CF0, length=length, xs=xs, rho=rho, species=None, betas=betas, freqs=freqs)
9199
cochlea.Ap_fun = Ap_func
92100
cochlea.Bu_fun = Bu_func
93101
return cochlea
@@ -111,23 +119,24 @@ def filter_at_location(self, x_coord, gain_const=None, peak_magndb=None, type='P
111119
return Filter(type=type, Ap=self.Ap_fun(x_coord), bp=self.bp_fun(x_coord), Bu=self.Bu_fun(x_coord), gain_const=gain_const, peak_magndb=peak_magndb, cf=self.cf(x_coord))
112120

113121
def _given_species(self, species):
114-
if species is not None:
115-
if species == 'chinchilla':
116-
CF0 = 28.131
117-
l_factor = 3.6649
118-
length = 35
119-
elif species == 'human':
120-
CF0 = 20.823
121-
l_factor = 7.2382
122-
length = 20
123-
# elif species == 'guinea pig' or species == 'guineapig':
124-
# elif species == 'mouse'
125-
xs = np.linspace(0, length, 5)[1:]
126-
Ap = [0.3768 * np.exp(-0.1366 * CF0 * np.exp(-x/l_factor)) for x in xs] # this seems improbable
127-
bp = [1., 1., 1., 1.]
128-
Bu = [3.714 * np.exp(0.03123 * CF0 * np.exp(-x/l_factor)) for x in xs] # same here
122+
if species == 'chinchilla': # Muller et al HR 2010
123+
CF0 = 28.131
124+
l_factor = 3.6649
125+
length = 20
126+
elif species == 'human':
127+
CF0 = 20.823
128+
l_factor = 7.2382
129+
length = 35
130+
elif species == 'guinea pig' or species == 'guineapig': # Tsuji and Liberman 1997 J. Comp. Neurol. 381:188-202
131+
CF0 = 54.732
132+
l_factor = 3.3178
133+
length = 20
134+
elif species == 'mouse':
135+
CF0 = 71.130
136+
l_factor = 1.8566
137+
length = 5.13
129138
else:
130-
l_factor = 3.8
139+
raise Exception(f'"{species}" is an unsupported species')
131140

132141
return (CF0, l_factor, length)
133142

@@ -191,7 +200,7 @@ def plot_impedance(self, betas=None, setting='realimag', custom_title='Normalize
191200
show: `True` if plot is to be shown, `False` otherwise. Default is `True`.
192201
phase_in_rad: Show phase in radians if True or in cycles otherwise
193202
'''
194-
# plot Z and normalized Z
203+
# plot Z and normalized Z?
195204
if betas is None:
196205
betas = np.linspace(0.01, self.bp_apex*1.5, 10000)
197206
Zdata = self.Z_norm(betas)
@@ -235,7 +244,6 @@ def signal_response_heatmap(self, signal, len_xs=20, custom_title='Signal Heatma
235244
'''
236245
sigs = []
237246
cfs = []
238-
# signal.plot()
239247
for x in np.linspace(0.025, self.cochlea_length, len_xs):
240248
fil = self.filter_at_location(x)
241249
cfs += [round(self.cf(x), 2)]

Filter.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, ir=None, tf=None, coeffs=None, roots=None, type=None, Ap=None
6767
has_roots = (roots is not None)
6868
has_params = any(param is not None for param in [Ap, bp, Bu])
6969
has_chars = any(characteristic is not None for characteristic in [Bpeak, fpeak, phiaccum, Nbeta, Nf, Qerb, ERBbeta, ERBf, Qn, Qn2, BWndBbeta, BWndBf, BWn2dBbeta, BWn2dBf, Sbeta, Sf])
70-
if sum([has_coeffs, has_roots, has_tf, has_params, has_chars]) != 1:
70+
if sum([has_coeffs, has_roots, has_tf, has_ir, has_params, has_chars]) != 1:
7171
raise Exception('Exactly one filter representation should be used')
7272

7373
self.in_terms_of_normalized = True
@@ -86,7 +86,13 @@ def __init__(self, ir=None, tf=None, coeffs=None, roots=None, type=None, Ap=None
8686
if Sf is not None: Sbeta = Sf * self.cf**2
8787
if Nf is not None: Nbeta = Nf * self.cf
8888

89-
if betas is None: betas = np.geomspace(0.01, 10, 10000) # is there a more adaptive way to pick betas if it is not provided
89+
if betas is None:
90+
maxbeta = 10
91+
if bp is not None:
92+
maxbeta = 3*bp
93+
if Bpeak is not None:
94+
maxbeta = 3*Bpeak
95+
betas = np.linspace(0.01, maxbeta, 10000)
9096

9197
# eventually refactor by making the __init__ of the three classes deal with the logic (which it actually already mostly does)
9298
if has_tf:
@@ -510,7 +516,7 @@ def pole_zero_plot(self, custom_title=None, show=True):
510516
ax.axhline(y=0, color='k', ls=':')
511517
ax.axvline(x=0, color='k', ls=':')
512518
ax.scatter([z.real for z in zeros], [z.imag for z in zeros], marker='o', facecolors='none', edgecolors='tab:orange')
513-
ax.scatter([z.real for z in poles], [z.imag for z in poles], marker='x', edgecolors='tab:blue')
519+
ax.scatter([z.real for z in poles], [z.imag for z in poles], marker='x', facecolors='tab:blue')
514520
ax.set_xlabel('Re(z)')
515521
ax.set_ylabel('Im(z)')
516522
plt.axis('equal')

FilterType.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,19 @@ def __init__(self, uid, tf=None, ir=None, betas=None):
7878
if betas is None:
7979
betas = np.geomspace(0.01, 10, 10000)
8080
self.init_with_tf = (tf is not None)
81-
chars = helpers.computedfiltercharacteristics(tfunc=tf, betas=betas)
81+
if self.init_with_tf:
82+
chars = helpers.computedfiltercharacteristics(tfunc=tf, betas=betas)
83+
else:
84+
# much faster than using approx_tf, even if it's not as accurate
85+
siglen = len(betas)*2
86+
samprate = betas[-1]*2
87+
timestamps = np.arange(siglen)/samprate
88+
irarr = [ir(t) for t in timestamps]
89+
tfapprox = scipy.fft.rfft(irarr)
90+
origbetas = scipy.fft.rfftfreq(siglen, d=1/samprate)
91+
chars = helpers.computedfiltercharacteristics(tfunc=(lambda f: np.interp(f, origbetas, tfapprox)), betas=betas)
92+
93+
# np.interp
8294

8395
if tf is not None:
8496
if ir is not None:

Signal.py

+26-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
2-
import scipy as sp
32
import scipy.fft
3+
import scipy.signal
4+
import scipy.integrate
5+
import scipy.io
46
import matplotlib.pyplot as plt
57

68
def tolerate(x, eps=1e-10):
@@ -48,15 +50,15 @@ def __init__(self, mode='t', data=[0 for _ in range(9)], fs=1, evenlen=True):
4850
# having self.length as length of self.mode_t keeps a bit more info (literally)
4951
if self.mode in ['t', 'ttilde']:
5052
self.mode_t = data
51-
self.mode_f = sp.fft.rfft(data)
53+
self.mode_f = scipy.fft.rfft(data)
5254
self.length = len(data)
5355
else:
5456
if self.mode == 'w':
5557
self.mode_f = data / 2 / np.pi
5658
else:
5759
self.mode_f = data
5860
self.length = 2*len(data)-1-int(evenlen) # -2 if evenlen is True and -1 if evenlen is False
59-
self.mode_t = sp.fft.irfft(self.mode_f, self.length)
61+
self.mode_t = scipy.fft.irfft(self.mode_f, self.length)
6062

6163
self.func = (lambda t: self.at_time(t, tolerance=0)) # automatically in time domain, from_function for other domains
6264
self.timestamps = np.arange(self.length)/fs
@@ -67,7 +69,7 @@ def __init__(self, mode='t', data=[0 for _ in range(9)], fs=1, evenlen=True):
6769
self.mean = np.mean(self.mode_t)
6870
self.rms = np.mean([x**2 for x in self.mode_t])**0.5
6971

70-
self.analytic = sp.signal.hilbert(self.mode_t)
72+
self.analytic = scipy.signal.hilbert(self.mode_t)
7173
self.hilbert = np.real(self.analytic)
7274
self.inst_phase = np.unwrap(np.angle(self.analytic))
7375

@@ -91,7 +93,7 @@ def from_function(cls, mode='t', func=(lambda x: 0), fs=1, num_samples=9, evenle
9193
if mode in ['t', 'ttilde']:
9294
sample_points = np.arange(num_samples)/fs
9395
else:
94-
sample_points = sp.fft.rfftfreq(2*num_samples-1-int(evenlen), 1/fs)
96+
sample_points = scipy.fft.rfftfreq(2*num_samples-1-int(evenlen), 1/fs)
9597
S = cls(mode=mode, data=func(sample_points), fs=fs, evenlen=evenlen)
9698
S.func = func
9799
return S
@@ -125,17 +127,17 @@ def linear_chirp(cls, f_init=1, w_init=None, f_final=10, w_final=None, fs=1, num
125127
return cls.from_function(mode='t', func=(lambda t: np.cos(np.pi*t*(fi*(2-t/endtime) + ff*(t/endtime)))), fs=fs, num_samples=num_samples, evenlen=(num_samples%2==0))
126128
# evenlen is actually never used in this case since the mode is always 't', but just for completeness
127129

128-
# @classmethod
129-
# def from_instantaneous_frequency(cls, freq_func=(lambda x: 0), freqs=None, init_phase=0, fs=1, num_samples=9):
130-
# if freqs is None:
131-
# # gaussian quadrature?
132-
# ws = [2*np.pi*freq_func(i/fs) for i in range(num_samples)]
133-
# else:
134-
# ws = [2*np.pi*f for f in freqs]
130+
@classmethod
131+
def from_instantaneous_frequency(cls, freq_func=(lambda x: 0), freqs=None, init_phase=0, fs=1, num_samples=9):
132+
if freqs is None:
133+
# gaussian quadrature?
134+
ws = [2*np.pi*freq_func(i/fs) for i in range(num_samples)]
135+
else:
136+
ws = [2*np.pi*f for f in freqs]
135137

136-
# phases = sp.integrate.cumulative_trapezoid(ws, dx=1/fs, initial=0)+init_phase
138+
phases = scipy.integrate.cumulative_trapezoid(ws, dx=1/fs, initial=0)+init_phase
137139

138-
# return cls(mode='t', data=np.cos(phases).tolist(), fs=fs)
140+
return cls(mode='t', data=np.cos(phases).tolist(), fs=fs)
139141

140142
def __iter__(self):
141143
self.__idx = -1
@@ -253,13 +255,15 @@ def at_time(self, t, tolerance=1e-10):
253255
if tolerate(num, eps=tolerance):
254256
return self.mode_t[round(num)%self.length]
255257

256-
f = [k/self.length for k in self.mode_f]
258+
f = self.mode_f
259+
freqstamps = self.freqstamps
257260

258-
tot = f[0].real
261+
tot = np.real(f[0])
259262
for idx in range(1, len(f)):
260-
tot += 2 * (f[idx] * np.exp(2j*np.pi*idx*t/self.length)).real
263+
tot += 2 * np.real(f[idx] * np.exp(2j*np.pi*t*freqstamps[idx]))
261264
if self.length%2 == 0:
262-
tot -= (f[idx] * np.exp(2j*np.pi*idx*t/self.length)).real
265+
tot -= np.real(f[idx] * np.exp(2j*np.pi*t*freqstamps[idx]))
266+
tot /= self.length
263267
if tolerate(tot, eps=tolerance):
264268
tot = round(tot)
265269
return tot
@@ -358,12 +362,12 @@ def moving_spectral_entropy(self, window_len=9):
358362
Hs = []
359363
for i in range(self.length-window_len+1):
360364
data = self.mode_t[i:i+9]
361-
spectrum = abs(sp.fft.fft(data))**2
365+
spectrum = abs(scipy.fft.fft(data))**2
362366
distribution = spectrum/sum(spectrum)
363367
Hs += [-sum(p*np.log(p) for p in distribution)/np.log(window_len)]
364368
return Hs
365369

366-
def spectrogram(self, win=sp.signal.windows.gaussian(30, std=5, sym=True), hop=1, mfft=200, custom_title='Spectrogram', show=True):
370+
def spectrogram(self, win=scipy.signal.windows.gaussian(30, std=5, sym=True), hop=1, mfft=200, custom_title='Spectrogram', show=True):
367371
'''
368372
Generates spectrogram of Signal. Returns [SFFT data, bounds]. \
369373
Since the window has a small width, the resulting SFFT is \
@@ -377,7 +381,7 @@ def spectrogram(self, win=sp.signal.windows.gaussian(30, std=5, sym=True), hop=1
377381
show: `True` if plot is to be shown, `False` otherwise. Default is `True`.
378382
'''
379383
N = self.length
380-
ft = sp.signal.ShortTimeFFT(np.array(win), hop, self.fs, mfft=mfft)
384+
ft = scipy.signal.ShortTimeFFT(np.array(win), hop, self.fs, mfft=mfft)
381385
S = ft.spectrogram(np.array(self.mode_t))
382386
windowed_S = S[:, ft.lower_border_end[1]:ft.upper_border_begin(N)[1]]
383387
bounds = (0, ft.delta_t*len(windowed_S[0]), *ft.extent(N)[2:])
@@ -471,4 +475,4 @@ def as_sound(self, filename):
471475
Attributes:
472476
filename: Filename to save sound file under.
473477
'''
474-
sp.io.wavfile.write(filename, self.fs, np.array(self.mode_t))
478+
scipy.io.wavfile.write(filename, self.fs, np.array(self.mode_t))

0 commit comments

Comments
 (0)