diff --git a/.github/scripts/unittest-linux/install.sh b/.github/scripts/unittest-linux/install.sh index 8e550dbf2d..c04e9747d0 100755 --- a/.github/scripts/unittest-linux/install.sh +++ b/.github/scripts/unittest-linux/install.sh @@ -40,7 +40,7 @@ case $GPU_ARCH_TYPE in ;; esac PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${GPU_ARCH_ID}" -pip install --progress-bar=off --pre torch torchcodec --index-url="${PYTORCH_WHEEL_INDEX}" +pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" # 2. Install torchaudio @@ -54,6 +54,5 @@ pip install . -v --no-build-isolation printf "* Installing test tools\n" # On this CI, for whatever reason, we're only able to install ffmpeg 4. conda install -y "ffmpeg<5" -python -c "import torch; import torchaudio; import torchcodec; print(torch.__version__, torchaudio.__version__, torchcodec.__version__)" pip3 install parameterized requests coverage pytest pytest-cov scipy numpy expecttest diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml index e92c556218..f681e3b7ec 100644 --- a/.github/workflows/build_docs.yml +++ b/.github/workflows/build_docs.yml @@ -68,7 +68,7 @@ jobs: GPU_ARCH_ID=cu126 # This is hard-coded and must be consistent with gpu-arch-version. 
PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${CHANNEL}/${GPU_ARCH_ID}" - pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" + pip install --progress-bar=off --pre torch torchcodec --index-url="${PYTORCH_WHEEL_INDEX}" echo "::endgroup::" echo "::group::Install TorchAudio" diff --git a/src/torchaudio/__init__.py b/src/torchaudio/__init__.py index e533cafe9d..78f42e5cfb 100644 --- a/src/torchaudio/__init__.py +++ b/src/torchaudio/__init__.py @@ -1,4 +1,7 @@ from torchaudio._internal.module_utils import dropping_io_support, dropping_class_io_support +from typing import Union, BinaryIO, Optional, Tuple +import os +import torch # Initialize extension and backend first from . import _extension # noqa # usort: skip @@ -7,8 +10,6 @@ get_audio_backend as _get_audio_backend, info as _info, list_audio_backends as _list_audio_backends, - load, - save, set_audio_backend as _set_audio_backend, ) from ._torchcodec import load_with_torchcodec, save_with_torchcodec @@ -41,6 +42,172 @@ pass +def load( + uri: Union[BinaryIO, str, os.PathLike], + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, + buffer_size: int = 4096, + backend: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + """Load audio data from source using TorchCodec's AudioDecoder. + + .. note:: + + As of TorchAudio 2.9, this function relies on TorchCodec's decoding capabilities under the hood. It is + provided for convenience, but we do recommend that you port your code to + natively use ``torchcodec``'s ``AudioDecoder`` class for better + performance: + https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder. + Because of the reliance on Torchcodec, the parameters ``normalize``, ``buffer_size``, and + ``backend`` are ignored and accepted only for backwards compatibility. + + + Args: + uri (path-like object or file-like object): + Source of audio data. 
The following types are accepted: + + * ``path-like``: File path or URL. + * ``file-like``: Object with ``read(size: int) -> bytes`` method. + + frame_offset (int, optional): + Number of samples to skip before starting to read data. + num_frames (int, optional): + Maximum number of samples to read. ``-1`` reads all the remaining samples, + starting from ``frame_offset``. + normalize (bool, optional): + TorchCodec always returns normalized float32 samples. This parameter + is ignored and a warning is issued if set to False. + Default: ``True``. + channels_first (bool, optional): + When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Format hint for the decoder. May not be supported by all TorchCodec + decoders. (Default: ``None``) + buffer_size (int, optional): + Not used by TorchCodec AudioDecoder. Provided for API compatibility. + backend (str or None, optional): + Not used by TorchCodec AudioDecoder. Provided for API compatibility. + + Returns: + (torch.Tensor, int): Resulting Tensor and sample rate. + Always returns float32 tensors. If ``channels_first=True``, shape is + `[channel, time]`, otherwise `[time, channel]`. + + Raises: + ImportError: If torchcodec is not available. + ValueError: If unsupported parameters are used. + RuntimeError: If TorchCodec fails to decode the audio. + + Note: + - TorchCodec always returns normalized float32 samples, so the ``normalize`` + parameter has no effect. + - The ``buffer_size`` and ``backend`` parameters are ignored. + - Not all audio formats supported by torchaudio backends may be supported + by TorchCodec.
+ """ + return load_with_torchcodec( + uri, + frame_offset=frame_offset, + num_frames=num_frames, + normalize=normalize, + channels_first=channels_first, + format=format, + buffer_size=buffer_size, + backend=backend + ) + +def save( + uri: Union[str, os.PathLike], + src: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, + buffer_size: int = 4096, + backend: Optional[str] = None, + compression: Optional[Union[float, int]] = None, +) -> None: + """Save audio data to file using TorchCodec's AudioEncoder. + + .. note:: + + As of TorchAudio 2.9, this function relies on TorchCodec's encoding capabilities under the hood. + It is provided for convenience, but we do recommend that you port your code to + natively use ``torchcodec``'s ``AudioEncoder`` class for better + performance: + https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder. + Because of the reliance on Torchcodec, the parameters ``format``, ``encoding``, + ``bits_per_sample``, ``buffer_size``, and ``backend``, are ignored and accepted only for + backwards compatibility. + + Args: + uri (path-like object): + Path to save the audio file. The file extension determines the format. + + src (torch.Tensor): + Audio data to save. Must be a 1D or 2D tensor with float32 values + in the range [-1, 1]. If 2D, shape should be [channel, time] when + channels_first=True, or [time, channel] when channels_first=False. + + sample_rate (int): + Sample rate of the audio data. + + channels_first (bool, optional): + Indicates whether the input tensor has channels as the first dimension. + If True, expects [channel, time]. If False, expects [time, channel]. + Default: True. + + format (str or None, optional): + Audio format hint. Not used by TorchCodec (format is determined by + file extension). A warning is issued if provided. + Default: None. 
+ + encoding (str or None, optional): + Audio encoding. Not fully supported by TorchCodec AudioEncoder. + A warning is issued if provided. Default: None. + + bits_per_sample (int or None, optional): + Bits per sample. Not directly supported by TorchCodec AudioEncoder. + A warning is issued if provided. Default: None. + + buffer_size (int, optional): + Not used by TorchCodec AudioEncoder. Provided for API compatibility. + A warning is issued if not default value. Default: 4096. + + backend (str or None, optional): + Not used by TorchCodec AudioEncoder. Provided for API compatibility. + A warning is issued if provided. Default: None. + + compression (float, int or None, optional): + Compression level or bit rate. Maps to bit_rate parameter in + TorchCodec AudioEncoder. Default: None. + + Raises: + ImportError: If torchcodec is not available. + ValueError: If input parameters are invalid. + RuntimeError: If TorchCodec fails to encode the audio. + + Note: + - TorchCodec AudioEncoder expects float32 samples in [-1, 1] range. + - Some parameters (format, encoding, bits_per_sample, buffer_size, backend) + are not used by TorchCodec but are provided for API compatibility. + - The output format is determined by the file extension in the uri. + - TorchCodec uses FFmpeg under the hood for encoding. + """ + return save_with_torchcodec(uri, src, sample_rate, + channels_first=channels_first, + format=format, + encoding=encoding, + bits_per_sample=bits_per_sample, + buffer_size=buffer_size, + backend=backend, + compression=compression) + __all__ = [ "AudioMetaData", "load", diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000000..3b16aab043 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,23 @@ +import sys +from pathlib import Path + +# Note: [TorchCodec test dependency mocking hack] +# We are adding the `test/` directory to the system path. 
This causes the +# `test/torchcodec` folder to be importable, and in particular, this makes it +# possible to mock torchcodec utilities. E.g. executing: +# +# ``` +# from torchcodec.decoders import AudioDecoder +# ``` +# directly or indirectly when running the tests will effectively be loading the +# mocked `AudioDecoder` implemented in `test/torchcodec/decoders.py`, which +# relies on scipy instead of relying on torchcodec. +# +# So whenever `torchaudio.load()` is called from within the tests, it's the +# mocked scipy `AudioDecoder` that gets used. Ultimately, this allows us *not* +# to add torchcodec as a test dependency of torchaudio: we can just rely on +# scipy. +# +# This is VERY hacky and ideally we should implement a more robust way to mock +# torchcodec. +sys.path.append(str(Path(__file__).parent.resolve())) diff --git a/test/torchaudio_unittest/test_load_save_torchcodec.py b/test/torchaudio_unittest/test_load_save_torchcodec.py index 3edb4c423b..9057e93811 100644 --- a/test/torchaudio_unittest/test_load_save_torchcodec.py +++ b/test/torchaudio_unittest/test_load_save_torchcodec.py @@ -12,6 +12,15 @@ from torchaudio import load_with_torchcodec, save_with_torchcodec from torchaudio_unittest.common_utils import get_asset_path +# These tests were run when `torchaudio.load()` and `torchaudio.save()` were +# still relying on their previous backends (ffmpeg, sox, soundfile). We needed +# to validate that the newly introduced `load_with_torchcodec()` and +# `save_with_torchcodec()` were matching their results. +# From 2.9, `load()` and `save()` now internally rely on `load_with_torchcodec()` and +# `save_with_torchcodec()` directly, so these tests are now redundant and we +# skip them unconditionally.
+pytest.skip(allow_module_level=True) + def get_ffmpeg_version(): """Get FFmpeg version to check for compatibility issues.""" try: @@ -48,25 +57,25 @@ def test_basic_load(filename): # Skip problematic files on FFmpeg4 due to known compatibility issues if is_ffmpeg4() and filename != "sinewave.wav": pytest.skip("FFmpeg4 has known compatibility issues with some audio files") - + file_path = get_asset_path(*filename.split("/")) - + # Load with torchaudio waveform_ta, sample_rate_ta = torchaudio.load(file_path) - + # Load with torchcodec waveform_tc, sample_rate_tc = load_with_torchcodec(file_path) - + # Check sample rates match assert sample_rate_ta == sample_rate_tc - + # Check shapes match assert waveform_ta.shape == waveform_tc.shape - + # Check data types (should both be float32) assert waveform_ta.dtype == torch.float32 assert waveform_tc.dtype == torch.float32 - + # Check values are close (allowing for small differences in decoders) torch.testing.assert_close(waveform_ta, waveform_tc) @@ -79,17 +88,17 @@ def test_basic_load(filename): def test_frame_offset_and_num_frames(frame_offset, num_frames): """Test frame_offset and num_frames parameters.""" file_path = get_asset_path("sinewave.wav") - + # Load with torchaudio waveform_ta, sample_rate_ta = torchaudio.load( file_path, frame_offset=frame_offset, num_frames=num_frames ) - + # Load with torchcodec waveform_tc, sample_rate_tc = load_with_torchcodec( file_path, frame_offset=frame_offset, num_frames=num_frames ) - + # Check results match assert sample_rate_ta == sample_rate_tc assert waveform_ta.shape == waveform_tc.shape @@ -98,21 +107,21 @@ def test_frame_offset_and_num_frames(frame_offset, num_frames): def test_channels_first(): """Test channels_first parameter.""" file_path = get_asset_path("sinewave.wav") # Use sinewave.wav for compatibility - + # Test channels_first=True (default) waveform_cf_true, sample_rate = load_with_torchcodec(file_path, channels_first=True) - + # Test channels_first=False 
waveform_cf_false, _ = load_with_torchcodec(file_path, channels_first=False) - + # Check that transpose relationship holds assert waveform_cf_true.shape == waveform_cf_false.transpose(0, 1).shape torch.testing.assert_close(waveform_cf_true, waveform_cf_false.transpose(0, 1)) - + # Compare with torchaudio waveform_ta_true, _ = torchaudio.load(file_path, channels_first=True) waveform_ta_false, _ = torchaudio.load(file_path, channels_first=False) - + assert waveform_cf_true.shape == waveform_ta_true.shape assert waveform_cf_false.shape == waveform_ta_false.shape torch.testing.assert_close(waveform_cf_true, waveform_ta_true) @@ -121,18 +130,18 @@ def test_channels_first(): def test_normalize_parameter_warning(): """Test that normalize=False produces a warning.""" file_path = get_asset_path("sinewave.wav") - + with pytest.warns(UserWarning, match="normalize=False.*ignored"): # This should produce a warning waveform, sample_rate = load_with_torchcodec(file_path, normalize=False) - + # Result should still be float32 (normalized) assert waveform.dtype == torch.float32 def test_buffer_size_parameter_warning(): """Test that non-default buffer_size produces a warning.""" file_path = get_asset_path("sinewave.wav") - + with pytest.warns(UserWarning, match="buffer_size.*not used"): # This should produce a warning waveform, sample_rate = load_with_torchcodec(file_path, buffer_size=8192) @@ -141,7 +150,7 @@ def test_buffer_size_parameter_warning(): def test_backend_parameter_warning(): """Test that specifying backend produces a warning.""" file_path = get_asset_path("sinewave.wav") - + with pytest.warns(UserWarning, match="backend.*not used"): # This should produce a warning waveform, sample_rate = load_with_torchcodec(file_path, backend="ffmpeg") @@ -156,10 +165,10 @@ def test_invalid_file(): def test_format_parameter(): """Test that format parameter produces a warning.""" file_path = get_asset_path("sinewave.wav") - + with pytest.warns(UserWarning, match="format.*not 
supported"): waveform, sample_rate = load_with_torchcodec(file_path, format="wav") - + # Check basic properties assert waveform.dtype == torch.float32 assert sample_rate > 0 @@ -168,17 +177,17 @@ def test_format_parameter(): def test_multiple_warnings(): """Test that multiple unsupported parameters produce multiple warnings.""" file_path = get_asset_path("sinewave.wav") - + with pytest.warns() as warning_list: # This should produce multiple warnings waveform, sample_rate = load_with_torchcodec( - file_path, - normalize=False, - buffer_size=8192, + file_path, + normalize=False, + buffer_size=8192, backend="ffmpeg" ) - - + + # Check that expected warnings are present messages = [str(w.message) for w in warning_list] assert any("normalize=False" in msg for msg in messages) @@ -194,30 +203,30 @@ def test_save_basic_save(filename): # Load a test file first file_path = get_asset_path(*filename.split("/")) waveform, sample_rate = torchaudio.load(file_path) - + with tempfile.TemporaryDirectory() as temp_dir: # Save with torchaudio ta_path = os.path.join(temp_dir, "ta_output.wav") torchaudio.save(ta_path, waveform, sample_rate) - + # Save with torchcodec tc_path = os.path.join(temp_dir, "tc_output.wav") save_with_torchcodec(tc_path, waveform, sample_rate) - + # Load both back and compare waveform_ta, sample_rate_ta = torchaudio.load(ta_path) waveform_tc, sample_rate_tc = torchaudio.load(tc_path) - + # Check sample rates match assert sample_rate_ta == sample_rate_tc - + # Check shapes match assert waveform_ta.shape == waveform_tc.shape - + # Check data types (should both be float32) assert waveform_ta.dtype == torch.float32 assert waveform_tc.dtype == torch.float32 - + # Check values are close (allowing for small differences in encoders) torch.testing.assert_close(waveform_ta, waveform_tc, atol=1e-3, rtol=1e-3) @@ -227,25 +236,25 @@ def test_save_channels_first(channels_first): """Test channels_first parameter.""" # Create test data if channels_first: - waveform = 
torch.randn(2, 16000) # [channel, time] + waveform = torch.rand(2, 16000) # [channel, time] else: - waveform = torch.randn(16000, 2) # [time, channel] - + waveform = torch.rand(16000, 2) # [time, channel] + sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: # Save with torchaudio ta_path = os.path.join(temp_dir, "ta_output.wav") torchaudio.save(ta_path, waveform, sample_rate, channels_first=channels_first) - + # Save with torchcodec tc_path = os.path.join(temp_dir, "tc_output.wav") save_with_torchcodec(tc_path, waveform, sample_rate, channels_first=channels_first) - + # Load both back and compare waveform_ta, sample_rate_ta = torchaudio.load(ta_path) waveform_tc, sample_rate_tc = torchaudio.load(tc_path) - + # Check results match assert sample_rate_ta == sample_rate_tc assert waveform_ta.shape == waveform_tc.shape @@ -256,15 +265,15 @@ def test_save_compression_parameter(): """Test compression parameter (maps to bit_rate).""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: # Test with compression (bit_rate) output_path = os.path.join(temp_dir, "output.wav") save_with_torchcodec(output_path, waveform, sample_rate, compression=128000) - + # Should not raise an error and file should exist assert os.path.exists(output_path) - + # Load back and check basic properties waveform_loaded, sample_rate_loaded = torchaudio.load(output_path) assert sample_rate_loaded == sample_rate @@ -275,13 +284,13 @@ def test_save_format_parameter_warning(): """Test that format parameter produces a warning.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + with pytest.warns(UserWarning, match="format.*not used"): save_with_torchcodec(output_path, waveform, sample_rate, format="wav") - + # Should still work despite warning assert os.path.exists(output_path) @@ -290,13 +299,13 @@ def 
test_save_encoding_parameter_warning(): """Test that encoding parameter produces a warning.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + with pytest.warns(UserWarning, match="encoding.*not fully supported"): save_with_torchcodec(output_path, waveform, sample_rate, encoding="PCM_16") - + # Should still work despite warning assert os.path.exists(output_path) @@ -305,13 +314,13 @@ def test_save_bits_per_sample_parameter_warning(): """Test that bits_per_sample parameter produces a warning.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + with pytest.warns(UserWarning, match="bits_per_sample.*not directly supported"): save_with_torchcodec(output_path, waveform, sample_rate, bits_per_sample=16) - + # Should still work despite warning assert os.path.exists(output_path) @@ -320,13 +329,13 @@ def test_save_buffer_size_parameter_warning(): """Test that non-default buffer_size produces a warning.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + with pytest.warns(UserWarning, match="buffer_size.*not used"): save_with_torchcodec(output_path, waveform, sample_rate, buffer_size=8192) - + # Should still work despite warning assert os.path.exists(output_path) @@ -335,13 +344,13 @@ def test_save_backend_parameter_warning(): """Test that specifying backend produces a warning.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + with pytest.warns(UserWarning, match="backend.*not used"): save_with_torchcodec(output_path, waveform, sample_rate, backend="ffmpeg") - + # Should still work despite warning assert os.path.exists(output_path) @@ 
-350,16 +359,16 @@ def test_save_edge_cases(): """Test edge cases and error conditions.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + # Test with very small waveform small_waveform = torch.randn(1, 10) save_with_torchcodec(output_path, small_waveform, sample_rate) waveform_loaded, sample_rate_loaded = torchaudio.load(output_path) assert sample_rate_loaded == sample_rate - + # Test with different sample rates for sr in [8000, 22050, 44100]: sr_path = os.path.join(temp_dir, f"output_{sr}.wav") @@ -372,19 +381,19 @@ def test_save_invalid_inputs(): """Test that invalid inputs raise appropriate errors.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + # Test with invalid sample rate with pytest.raises(ValueError, match="sample_rate must be positive"): save_with_torchcodec(output_path, waveform, -1) - + # Test with invalid tensor dimensions with pytest.raises(ValueError, match="Expected 1D or 2D tensor"): invalid_waveform = torch.randn(1, 2, 16000) # 3D tensor save_with_torchcodec(output_path, invalid_waveform, sample_rate) - + # Test with non-tensor input with pytest.raises(ValueError, match="Expected src to be a torch.Tensor"): save_with_torchcodec(output_path, [1, 2, 3], sample_rate) @@ -394,14 +403,14 @@ def test_save_multiple_warnings(): """Test that multiple unsupported parameters produce multiple warnings.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, "output.wav") - + with pytest.warns() as warning_list: save_with_torchcodec( - output_path, - waveform, + output_path, + waveform, sample_rate, format="wav", encoding="PCM_16", @@ -409,7 +418,7 @@ def test_save_multiple_warnings(): buffer_size=8192, backend="ffmpeg" ) - + # Check that expected 
warnings are present messages = [str(w.message) for w in warning_list] assert any("format" in msg for msg in messages) @@ -417,7 +426,7 @@ def test_save_multiple_warnings(): assert any("bits_per_sample" in msg for msg in messages) assert any("buffer_size" in msg for msg in messages) assert any("backend" in msg for msg in messages) - + # Should still work despite warnings assert os.path.exists(output_path) @@ -426,17 +435,17 @@ def test_save_different_formats(): """Test saving to different audio formats.""" waveform = torch.randn(1, 16000) sample_rate = 16000 - + with tempfile.TemporaryDirectory() as temp_dir: # Test common formats formats = ["wav", "mp3", "flac"] - + for fmt in formats: output_path = os.path.join(temp_dir, f"output.{fmt}") try: save_with_torchcodec(output_path, waveform, sample_rate) assert os.path.exists(output_path) - + # Try to load back (may not work for all formats with all backends) try: waveform_loaded, sample_rate_loaded = torchaudio.load(output_path) @@ -446,4 +455,4 @@ def test_save_different_formats(): pass except Exception as e: # Some formats might not be supported by torchcodec - pytest.skip(f"Format {fmt} not supported: {e}") \ No newline at end of file + pytest.skip(f"Format {fmt} not supported: {e}") diff --git a/test/torchcodec/decoders.py b/test/torchcodec/decoders.py new file mode 100644 index 0000000000..0064be91d6 --- /dev/null +++ b/test/torchcodec/decoders.py @@ -0,0 +1,15 @@ +import torchaudio_unittest.common_utils.wav_utils as wav_utils +from types import SimpleNamespace + +# See corresponding [TorchCodec test dependency mocking hack] note in +# conftest.py + +class AudioDecoder: + def __init__(self, uri): + self.uri = uri + data, sample_rate = wav_utils.load_wav(self.uri) + self.metadata = SimpleNamespace(sample_rate=sample_rate) + self.data = data + + def get_all_samples(self): + return SimpleNamespace(data=self.data) diff --git a/test/torchcodec/encoders.py b/test/torchcodec/encoders.py new file mode 100644 index 
0000000000..e6b0693018 --- /dev/null +++ b/test/torchcodec/encoders.py @@ -0,0 +1,13 @@ +import torchaudio_unittest.common_utils.wav_utils as wav_utils +from types import SimpleNamespace + +# See corresponding [TorchCodec test dependency mocking hack] note in +# conftest.py + +class AudioEncoder: + def __init__(self, data, sample_rate): + self.data = data + self.metadata = SimpleNamespace(sample_rate=sample_rate) + + def to_file(self, uri, bit_rate=None): + return wav_utils.save_wav(uri, self.data, self.metadata.sample_rate)