Skip to content

Commit a6227e5

Browse files
authored
Enhancement: compute the correlation between two matches instead of only counting the matches (#1)
* tmp: add correlation * cli: add a record command * add missing apply window of the Hamming window function * enhancement: Add a scoring function that uses the correlation + the count of matches for better results when matching against the database * cli: listen cmd should remove the temporary file created * spectrogram: make the windowing function optional to allow for easy testing
1 parent 333efff commit a6227e5

19 files changed

+221
-102
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ bin/*
1515

1616
# Ignore dataset by default
1717
assets/dataset/*
18+
19+
.idea/

cmd/musig/cmd/listen.go

+8-25
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
package cmd
22

33
import (
4-
"io/ioutil"
4+
log "github.com/sirupsen/logrus"
55
"os"
6-
"os/signal"
6+
"path"
77
"time"
88

9-
"github.com/sfluor/musig/internal/pkg/sound"
109
"github.com/spf13/cobra"
1110
)
1211

@@ -20,33 +19,17 @@ var listenCmd = &cobra.Command{
2019
Use: "listen",
2120
Short: "listen will record the microphone input and try to find a matching song from the database (Ctrl-C will stop the recording)",
2221
Run: func(cmd *cobra.Command, args []string) {
23-
tmpFile, err := ioutil.TempFile(os.TempDir(), "musig_record.wav")
24-
failIff(err, "error creating temporary file for recording in %s", tmpFile.Name())
25-
defer os.Remove(tmpFile.Name())
26-
27-
stopCh := make(chan struct{}, 1)
28-
sig := make(chan os.Signal, 1)
29-
signal.Notify(sig, os.Interrupt, os.Kill)
30-
22+
name := path.Join(os.TempDir(), "musig_record.wav")
3123
dur, err := cmd.Flags().GetDuration("duration")
3224
failIff(err, "could not get duration, got: %v", dur)
3325

34-
go func() {
35-
defer func() { stopCh <- struct{}{} }()
36-
for {
37-
select {
38-
case <-time.After(dur):
39-
return
40-
case <-sig:
41-
return
42-
}
26+
defer func() {
27+
if err := os.Remove(name); err != nil {
28+
log.Errorf("Failed to remove temporary file stored at %s used to record the sample: %s", name , err)
4329
}
4430
}()
4531

46-
err = sound.RecordWAV(tmpFile, stopCh)
47-
failIff(err, "an error occured recording WAV file")
48-
failIff(tmpFile.Sync(), "error syncing temp file")
49-
50-
cmdRead(tmpFile.Name())
32+
recordAudioToFile(name, dur)
33+
cmdRead(name)
5134
},
5235
}

cmd/musig/cmd/load.go

+9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
func init() {
1313
rootCmd.AddCommand(loadCmd)
1414
loadCmd.Flags().BoolP("dry-run", "d", false, "disable saving to the database")
15+
loadCmd.Flags().BoolP("reset", "r", false, "reset the database if it already exists")
1516
loadCmd.Flags().BoolP("verbose", "v", false, "enable verbose output")
1617
}
1718

@@ -29,6 +30,14 @@ var loadCmd = &cobra.Command{
2930
os.Exit(0)
3031
}
3132

33+
resetDB, err := cmd.Flags().GetBool("reset")
34+
if resetDB && err == nil {
35+
log.Info("removing the existing database...")
36+
if err := os.Remove(dbFile); err != nil {
37+
log.Errorf("Error removing the database at %s: %s", dbFile, err)
38+
}
39+
}
40+
3241
p, err := pipeline.NewDefaultPipeline(dbFile)
3342
failIff(err, "error creating pipeline")
3443
defer p.Close()

cmd/musig/cmd/read.go

+26-13
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package cmd
22

33
import (
44
"fmt"
5-
65
"github.com/sfluor/musig/internal/pkg/model"
76
"github.com/sfluor/musig/internal/pkg/pipeline"
7+
"github.com/sfluor/musig/pkg/dsp"
88
"github.com/spf13/cobra"
99
)
1010

@@ -29,34 +29,47 @@ func cmdRead(file string) {
2929
res, err := p.Process(file)
3030
failIff(err, "error processing file %s", file)
3131

32-
// Will hold a count of songID => occurences
33-
counts := map[uint32]int{}
32+
// Will hold a count of songID => occurrences
3433
keys := make([]model.EncodedKey, 0, len(res.Fingerprint))
34+
sample := map[model.EncodedKey]model.TableValue{}
35+
// songID => points that matched
36+
matches := map[uint32]map[model.EncodedKey]model.TableValue{}
3537

36-
for k := range res.Fingerprint {
38+
for k, v := range res.Fingerprint {
3739
keys = append(keys, k)
40+
sample[k] = v
3841
}
3942

4043
m, err := p.DB.Get(keys)
41-
for _, values := range m {
44+
for key, values := range m {
4245
for _, val := range values {
43-
counts[val.SongID] += 1
46+
47+
if _, ok := matches[val.SongID]; !ok {
48+
matches[val.SongID] = map[model.EncodedKey]model.TableValue{}
49+
}
50+
51+
matches[val.SongID][key] = val
4452
}
4553
}
4654

55+
// songID => correlation
56+
scores := map[uint32]float64{}
57+
for songID, points := range matches {
58+
scores[songID] = dsp.MatchScore(sample, points)
59+
}
60+
4761
var song string
48-
var max, total int
62+
var max float64
4963
fmt.Println("Matches:")
50-
for id, count := range counts {
64+
for id, score := range scores {
5165
name, err := p.DB.GetSong(id)
5266
failIff(err, "error getting song id: %d", id)
53-
fmt.Printf("\t- %s, count: %d\n", name, count)
54-
if count > max {
55-
song, max = name, count
67+
fmt.Printf("\t- %s, score: %f\n", name, score)
68+
if score > max {
69+
song, max = name, score
5670
}
57-
total += count
5871
}
5972

6073
fmt.Println("---")
61-
fmt.Printf("Song is: %s (count: %d, pct: %.2f %%)\n", song, max, 100*float64(max)/float64(total))
74+
fmt.Printf("Song is: %s (score: %f)\n", song, max)
6275
}

cmd/musig/cmd/record.go

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package cmd
2+
3+
import (
4+
"os"
5+
"os/signal"
6+
"time"
7+
8+
"github.com/sfluor/musig/internal/pkg/sound"
9+
"github.com/spf13/cobra"
10+
)
11+
12+
func init() {
13+
rootCmd.AddCommand(recordCmd)
14+
recordCmd.Flags().DurationP("duration", "d", 10*time.Second, "duration of the listening")
15+
}
16+
17+
// recordCmd represents the listen command
18+
var recordCmd = &cobra.Command{
19+
Use: "record",
20+
Short: "record will record the microphone input and save the signal to the given file",
21+
Args: cobra.MinimumNArgs(1),
22+
Run: func(cmd *cobra.Command, args []string) {
23+
dur, err := cmd.Flags().GetDuration("duration")
24+
failIff(err, "could not get duration, got: %v", dur)
25+
26+
recordAudioToFile(args[0], dur)
27+
},
28+
}
29+
30+
func recordAudioToFile(name string, duration time.Duration) {
31+
file, err := os.Create(name)
32+
failIff(err, "error creating file for recording in %s", name)
33+
34+
stopCh := make(chan struct{}, 1)
35+
sig := make(chan os.Signal, 1)
36+
signal.Notify(sig, os.Interrupt, os.Kill)
37+
38+
go func() {
39+
defer func() {
40+
stopCh <- struct{}{}
41+
}()
42+
for {
43+
select {
44+
case <-time.After(duration):
45+
return
46+
case <-sig:
47+
return
48+
}
49+
}
50+
}()
51+
52+
err = sound.RecordWAV(file, stopCh)
53+
failIff(err, "an error occurred recording WAV file")
54+
failIff(file.Sync(), "error syncing temp file")
55+
}

cmd/musig/cmd/spec.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func genSpectrogram(path string, imgPath string) error {
3636
}
3737
defer file.Close()
3838

39-
s := dsp.NewSpectrogrammer(model.DOWNSAMPLERATIO, model.MAXFREQ, model.SAMPLESIZE)
39+
s := dsp.NewSpectrogrammer(model.DownsampleRatio, model.MaxFreq, model.SampleSize, true)
4040

4141
spec, _, err := s.Spectrogram(file)
4242
if err != nil {

internal/pkg/db/bolt.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (db *BoltDB) Get(keys []model.EncodedKey) (map[model.EncodedKey][]model.Tab
6868
return nil
6969
})
7070

71-
return res, errors.Wrap(err, "an error occured when reading from bolt")
71+
return res, errors.Wrap(err, "an error occurred when reading from bolt")
7272
}
7373

7474
// Set stores the list of (key, value) into the bolt file
@@ -90,7 +90,7 @@ func (db *BoltDB) Set(batch map[model.EncodedKey]model.TableValue) error {
9090
return nil
9191
})
9292

93-
return errors.Wrap(err, "an error occured when writing to bolt")
93+
return errors.Wrap(err, "an error occurred when writing to bolt")
9494
}
9595

9696
// GetSongID does a song name => songID lookup in the database
@@ -113,7 +113,7 @@ func (db *BoltDB) GetSongID(name string) (uint32, error) {
113113
return fmt.Errorf("could not find id for song name: %s", name)
114114
})
115115

116-
return id, errors.Wrap(err, "an error occured when reading from bolt")
116+
return id, errors.Wrap(err, "an error occurred when reading from bolt")
117117
}
118118

119119
// GetSong does a songID => song name lookup in the database
@@ -133,7 +133,7 @@ func (db *BoltDB) GetSong(songID uint32) (string, error) {
133133
return nil
134134
})
135135

136-
return name, errors.Wrap(err, "an error occured when reading from bolt")
136+
return name, errors.Wrap(err, "an error occurred when reading from bolt")
137137
}
138138

139139
// SetSong stores a song name in the database and returns it's song ID
@@ -163,7 +163,7 @@ func (db *BoltDB) SetSong(song string) (uint32, error) {
163163
return errors.Wrap(b.Put(rawKey, []byte(song)), "error setting song")
164164
})
165165

166-
return songID, errors.Wrap(err, "an error occured when writing to bolt")
166+
return songID, errors.Wrap(err, "an error occurred when writing to bolt")
167167
}
168168

169169
func itob(s uint32) []byte {

internal/pkg/fingerprint/fingerprint_test.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ func TestFingerprinting440And880(t *testing.T) {
2222
}
2323

2424
func testFingerprintingOnFile(t *testing.T, path string) {
25-
sampleSize := model.SAMPLESIZE
25+
sampleSize := model.SampleSize
2626

2727
s := dsp.NewSpectrogrammer(
28-
model.DOWNSAMPLERATIO,
29-
model.MAXFREQ,
28+
model.DownsampleRatio,
29+
model.MaxFreq,
3030
sampleSize,
31+
true,
3132
)
3233

3334
file, err := os.Open(path)

internal/pkg/fingerprint/simple.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func (sf *SimpleFingerprinter) Fingerprint(songID uint32, cMap []model.Constella
3434
for i := 0; i+sf.lastOffset < length; i++ {
3535
anchor := cMap[i]
3636
for _, p := range cMap[i+sf.anchorOffset : i+sf.lastOffset] {
37-
res[model.NewTableKey(anchor, p).Encode()] = *model.NewTableValue(songID, anchor)
37+
res[model.NewAnchorKey(anchor, p).Encode()] = *model.NewTableValue(songID, anchor)
3838
}
3939
}
4040

internal/pkg/model/model.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ import (
66
"strings"
77
)
88

9-
// DOWNSAMPLERATIO is the default down sample ratio (4)
10-
const DOWNSAMPLERATIO = 4
9+
// DownsampleRatio is the default down sample ratio (4)
10+
const DownsampleRatio = 4
1111

12-
// SAMPLESIZE is the default sample size (1024)
13-
const SAMPLESIZE = 1024.0
12+
// SampleSize is the default sample size (1024)
13+
const SampleSize = 1024.0
1414

15-
// MAXFREQ is 5kHz
16-
const MAXFREQ = 5000.0
15+
// MaxFreq is 5kHz
16+
const MaxFreq = 5000.0
1717

1818
// ConstellationPoint represents a point in the constellation map (time + frequency)
1919
type ConstellationPoint struct {

internal/pkg/model/table.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ import (
88

99
// Used to down size the frequences to a 9 bit ints
1010
// XXX: we could also use 10.7Hz directly
11-
const freqStep = MAXFREQ / float64(1<<9)
11+
const freqStep = MaxFreq / float64(1<<9)
1212

1313
// Used to down size the delta times to 14 bit ints (we use 16s as the max duration)
1414
const deltaTimeStep = 16 / float64(1<<14)
1515

1616
// TableValueSize represents the TableValueSize when encoded in bytes
1717
const TableValueSize = 8
1818

19-
// TableKey represents a table key
20-
type TableKey struct {
19+
// AnchorKey represents a anchor key
20+
type AnchorKey struct {
2121
// Frequency of the anchor point for the given point's target zone
2222
AnchorFreq float64
2323
// Frequency of the given point
@@ -29,9 +29,9 @@ type TableKey struct {
2929
// EncodedKey represents an encoded key
3030
type EncodedKey uint32
3131

32-
// NewTableKey creates a new table key from the given anchor and the given point
33-
func NewTableKey(anchor, point ConstellationPoint) *TableKey {
34-
return &TableKey{
32+
// NewAnchorKey creates a new anchor key from the given anchor and the given point
33+
func NewAnchorKey(anchor, point ConstellationPoint) *AnchorKey {
34+
return &AnchorKey{
3535
AnchorFreq: anchor.Freq,
3636
PointFreq: point.Freq,
3737
// Use absolute just in case anchor is after the target zone
@@ -47,13 +47,13 @@ func (ek EncodedKey) Bytes() []byte {
4747
return bk
4848
}
4949

50-
// Encode encodes the table key using:
50+
// Encode encodes the anchor key using:
5151
// 9 bits for the “frequency of the anchor”: fa
5252
// 9 bits for the ” frequency of the point”: fp
5353
// 14 bits for the ”delta time between the anchor and the point”: dt
5454
// The result is then dt | fa | fp
5555
// XXX: this only works if frequencies are coded in 9 bits or less (if we used a 1024 samples FFT, it will be the case)
56-
func (tk *TableKey) Encode() EncodedKey {
56+
func (tk *AnchorKey) Encode() EncodedKey {
5757
// down size params
5858
fp := uint32(tk.PointFreq / freqStep)
5959
fa := uint32(tk.AnchorFreq / freqStep)

internal/pkg/pipeline/pipeline.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func NewDefaultPipeline(dbFile string) (*Pipeline, error) {
2424
if err != nil {
2525
return nil, errors.Wrapf(err, "error connection to database at: %s", dbFile)
2626
}
27-
s := dsp.NewSpectrogrammer(model.DOWNSAMPLERATIO, model.MAXFREQ, model.SAMPLESIZE)
27+
s := dsp.NewSpectrogrammer(model.DownsampleRatio, model.MaxFreq, model.SampleSize, true)
2828
fpr := fingerprint.NewDefaultFingerprinter()
2929

3030
return &Pipeline{

internal/pkg/sound/wav.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import (
55

66
"github.com/gordonklaus/portaudio"
77
"github.com/pkg/errors"
8-
riff "github.com/youpy/go-riff"
9-
wav "github.com/youpy/go-wav"
8+
"github.com/youpy/go-riff"
9+
"github.com/youpy/go-wav"
1010
)
1111

1212
var _ Reader = &WAVReader{}

0 commit comments

Comments
 (0)