Merge pull request #464 from facebookresearch/sync_03_06_2024_071201CEST
Release 1.4.0a1
antoine-tran committed Jun 18, 2024
2 parents 72cb16f + f2fbfff commit dff1dae
Showing 39 changed files with 3,462 additions and 124 deletions.
6 changes: 3 additions & 3 deletions .github/actions/audiocraft_build/action.yml
@@ -5,7 +5,7 @@ runs:
steps:
- uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.9
- uses: actions/cache@v2
id: cache
with:
@@ -21,9 +21,9 @@ runs:
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
pip install torch torchvision torchaudio
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
pip install xformers
pip install -e '.[dev]'
pip install -e '.[dev,wm]'
- name: System Dependencies
shell: bash
run: |
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [1.4.0a1] - 2024-06-03

Adding a new metric, PesqMetric ([Perceptual Evaluation of Speech Quality](https://doi.org/10.5281/zenodo.6549559)).

Adding multiple audio augmentation functions: pink noise generation, up-/downsampling, low-/highpass filtering, bandpass filtering, smoothing, duck masking, and boosting. All are wrapped in `audiocraft.utils.audio_effects.AudioEffects` and can be selected via the API `audiocraft.utils.audio_effects.select_audio_effects` (see the usage sketch after this diff).

Adding training code for AudioSeal (https://arxiv.org/abs/2401.17264) along with the [HF checkpoints](https://huggingface.co/facebook/audioseal).

## [1.3.0] - 2024-05-02

Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app.
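For reference, a minimal usage sketch of the augmentation API mentioned in the changelog entry above. Only the two entry points (`AudioEffects`, `select_audio_effects`) come from the changelog; the individual effect names, their parameters, and the configuration format are assumptions for illustration, not confirmed by this commit.

```python
# Hypothetical sketch -- effect names, parameters, and the config format below are
# assumptions; only the two imported entry points are named in the changelog.
import torch
from audiocraft.utils.audio_effects import AudioEffects, select_audio_effects

wav = torch.randn(1, 1, 16000)  # dummy batch: 1 second of mono audio at 16 kHz

# Apply a single augmentation directly (assuming effects are exposed as callables).
noisy = AudioEffects.pink_noise(wav)

# Or pick a named subset of effects to randomize over during watermark training.
effects = select_audio_effects({"pink_noise": {}, "lowpass_filter": {}})
```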
4 changes: 4 additions & 0 deletions Makefile
@@ -9,6 +9,9 @@ INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_m
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
checkpoint.save_last=false # Using compression model from 616d7b3c
INTEG_WATERMARK = AUDIOCRAFT_DORA_DIR="/tmp/wm_$(USER)" dora run device=cpu dataset.num_workers=0 optim.epochs=1 \
dataset.train.num_samples=10 dataset.valid.num_samples=10 dataset.evaluate.num_samples=10 dataset.generate.num_samples=10 \
logging.level=DEBUG solver=watermark/robustness checkpoint.save_last=false dset=audio/example

default: linter tests

@@ -29,6 +32,7 @@ tests_integ:
$(INTEG_MBD)
$(INTEG_MUSICGEN)
$(INTEG_AUDIOGEN)
$(INTEG_WATERMARK)


api_docs:
2 changes: 2 additions & 0 deletions README.md
@@ -20,6 +20,7 @@ python -m pip install setuptools wheel
python -m pip install -U audiocraft # stable release
python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
python -m pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
python -m pip install -e '.[wm]' # if you want to train a watermarking model
```

We also recommend having `ffmpeg` installed, either through your system or Anaconda:
@@ -37,6 +38,7 @@ At the moment, AudioCraft contains the training code and inference code for:
* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
* [MAGNeT](./docs/MAGNET.md): A state-of-the-art non-autoregressive model for text-to-music and text-to-sound.
* [AudioSeal](./docs/WATERMARKING.md): A state-of-the-art audio watermarking model.

## Training code

2 changes: 1 addition & 1 deletion audiocraft/__init__.py
@@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

__version__ = '1.3.0'
__version__ = '1.4.0a1'
122 changes: 121 additions & 1 deletion audiocraft/data/audio.py
@@ -114,7 +114,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa


def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
"""Read audio by picking the most appropriate backend tool based on the audio format.
Args:
@@ -229,3 +229,123 @@ def audio_write(stem_name: tp.Union[str, Path],
            path.unlink()
        raise
    return path


def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray:
    """Compute the mel-spectrogram of a raw audio signal.
    Args:
        y (numpy array): Raw audio input
        sr (int): Sampling rate. Default is 16000.
        n_fft (int): Number of samples per FFT. Default is 4096.
        hop_length (int): Number of samples between successive frames. Default is 128.
        dur (float): Maximum duration, in seconds, of the spectrogram (currently unused). Default is 8.
    Returns:
        Mel-spectrogram in dB as a numpy array
    """
    import librosa
    import librosa.display

    spectrogram = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft, hop_length=hop_length
    )
    spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
    return spectrogram_db


def save_spectrograms(
    ys: tp.List[np.ndarray],
    sr: int,
    path: str,
    names: tp.List[str],
    n_fft: int = 4096,
    hop_length: int = 128,
    dur: float = 8.0,
):
    """Plot the spectrograms of several audio signals and save them to a single figure.
    Args:
        ys: List of raw audio signals
        sr (int): Sampling rate of the audio signals
        path (str): Path to the output plot file
        names: Name of each spectrogram plot
        n_fft (int): Number of samples per FFT. Default is 4096.
        hop_length (int): Number of samples between successive frames. Default is 128.
        dur (float): Maximum duration, in seconds, to plot for each signal. Default is 8.0.
    Returns:
        None (saves the figure to `path` using matplotlib)
    """
    import matplotlib as mpl  # type: ignore
    import matplotlib.pyplot as plt  # type: ignore
    import librosa.display

    if not names:
        names = ["Ground Truth", "Audio Watermarked", "Watermark"]
    ys = [wav[: int(dur * sr)] for wav in ys]  # crop
    assert len(names) == len(
        ys
    ), f"There are {len(ys)} wavs but {len(names)} names ({names})"

    # Configure matplotlib fonts and sizes
    BIGGER_SIZE = 10
    SMALLER_SIZE = 8
    linewidth = 234.8775  # linewidth in pt

    plt.rc("font", size=BIGGER_SIZE, family="serif")  # controls default text sizes
    plt.rcParams["font.family"] = "DejaVu Serif"
    plt.rcParams["font.serif"] = ["Times New Roman"]

    plt.rc("axes", titlesize=BIGGER_SIZE)  # fontsize of the axes title
    plt.rc("axes", labelsize=BIGGER_SIZE)  # fontsize of the x and y labels
    plt.rc("xtick", labelsize=BIGGER_SIZE)  # fontsize of the tick labels
    plt.rc("ytick", labelsize=SMALLER_SIZE)  # fontsize of the tick labels
    plt.rc("legend", fontsize=BIGGER_SIZE)  # legend fontsize
    plt.rc("figure", titlesize=BIGGER_SIZE)
    height = 1.6 * linewidth / 72.0
    fig, ax = plt.subplots(
        nrows=len(ys),
        ncols=1,
        sharex=True,
        figsize=(linewidth / 72.0, height),
    )
    fig.tight_layout()

    # Plot the spectrograms
    for i, ysi in enumerate(ys):
        spectrogram_db = get_spec(ysi, sr=sr, n_fft=n_fft, hop_length=hop_length)
        if i == 0:
            cax = fig.add_axes(
                [
                    ax[0].get_position().x1 + 0.01,  # type: ignore
                    ax[-1].get_position().y0,
                    0.02,
                    ax[0].get_position().y1 - ax[-1].get_position().y0,
                ]
            )
            fig.colorbar(
                mpl.cm.ScalarMappable(
                    norm=mpl.colors.Normalize(
                        np.min(spectrogram_db), np.max(spectrogram_db)
                    ),
                    cmap="magma",
                ),
                ax=ax,
                orientation="vertical",
                format="%+2.0f dB",
                cax=cax,
            )
        librosa.display.specshow(
            spectrogram_db,
            sr=sr,
            hop_length=hop_length,
            x_axis="time",
            y_axis="mel",
            ax=ax[i],
        )
        ax[i].set(title=names[i])
        ax[i].yaxis.set_label_text(None)
        ax[i].label_outer()
    fig.savefig(path, bbox_inches="tight")
    plt.close()
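To make the new helpers above concrete, here is a short, self-contained example of how they could be called; the random waveform and output path are placeholders, not part of the commit.

```python
# Illustrative use of get_spec / save_spectrograms as added above; the dummy
# waveform and output filename are placeholders.
import numpy as np
from audiocraft.data.audio import get_spec, save_spectrograms

sr = 16000
y = np.random.randn(sr * 8).astype(np.float32)   # 8 s of dummy mono audio
spec_db = get_spec(y, sr=sr)                     # mel-spectrogram in dB
print(spec_db.shape)                             # (n_mels, n_frames)

save_spectrograms(
    [y, 0.5 * y], sr=sr, path="spectrograms.png",
    names=["Ground Truth", "Audio Watermarked"],
)
```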
