-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Train speech language ID classification head #450
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
am831
wants to merge
91
commits into
facebookresearch:main
Choose a base branch
from
am831:language_id
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 15 commits
Commits
Show all changes
91 commits
Select commit
Hold shift + click to select a range
738c135
classification head class
am831 feac449
finetune script progress
am831 dc6cbdb
add comment
am831 350f6e2
add layers
am831 3c81c28
Model freeze und save classification head
zrthxn 74f1f2d
Implement train loop
zrthxn 3280b64
Implement train loop
zrthxn ca42666
calc loss, class head params
am831 13901cb
Refactor
zrthxn db3df0c
fix errors
am831 1b84a5b
get vector dimensions within classification head
am831 c343cd4
hidden_dim
am831 bc7a37f
log and capture interrupts
am831 d2a65f8
dataset prep and plotting loss
am831 4c65310
Merge branch 'facebookresearch:main' into language_id
am831 24781c9
Language ID Dataloader (#2)
zrthxn ce04f48
Model Fixes (#3)
zrthxn 6584115
Merge branch 'facebookresearch:main' into language_id
am831 2567ebe
Merge branch 'language_id' of https://github.com/am831/seamless_commu…
am831 5efb06c
classification head class
am831 d0e8efc
finetune script progress
am831 6d76952
add comment
am831 64ca2be
add layers
am831 373312c
Model freeze und save classification head
zrthxn 9ee465f
Implement train loop
zrthxn e1ab896
Implement train loop
zrthxn 67dd5fb
calc loss, class head params
am831 be33445
Refactor
zrthxn 70ef27c
fix errors
am831 4be035e
get vector dimensions within classification head
am831 22786a0
hidden_dim
am831 70f93da
log and capture interrupts
am831 8eb5195
dataset prep and plotting loss
am831 0794e02
Language ID Dataloader (#2)
zrthxn c485442
Model Fixes (#3)
zrthxn 7d2c589
get embed_dim dynamically (#4)
am831 d9d6dc0
Model Fixes
zrthxn fc7984c
Merge branch 'language_id' of https://github.com/am831/seamless_commu…
am831 3de26f4
save plot as pkl
am831 caf9a08
address some feedback
zrthxn a928a51
Merge pull request #6 from am831/changes_lanID
am831 1347b46
switch model to train mode
am831 9184204
Code cleanup
zrthxn 842e1c8
BCE loss
zrthxn 19653d7
Remove Label smoothing
zrthxn 588b0a8
change model to increase train loss
am831 480d318
classification head class
am831 fcb9e90
finetune script progress
am831 3581ff9
add comment
am831 3c5adc2
add layers
am831 2c47cde
Model freeze und save classification head
zrthxn c63c427
Implement train loop
zrthxn 2748412
Implement train loop
zrthxn efc93e8
calc loss, class head params
am831 0bed503
Refactor
zrthxn e1f75fd
fix errors
am831 432e692
get vector dimensions within classification head
am831 56f72de
hidden_dim
am831 eb65fa9
log and capture interrupts
am831 04eea4f
dataset prep and plotting loss
am831 f41814f
Language ID Dataloader (#2)
zrthxn 3839e24
Model Fixes (#3)
zrthxn 4d2f435
classification head class
am831 d1ab8ad
finetune script progress
am831 631305f
add comment
am831 6b81e8e
add layers
am831 b266e77
Model freeze und save classification head
zrthxn 2f4b117
Implement train loop
zrthxn 2ce0eca
Implement train loop
zrthxn 228bf5f
calc loss, class head params
am831 b33a1d1
Refactor
zrthxn 2eb1732
fix errors
am831 cbaef16
get vector dimensions within classification head
am831 b5358de
hidden_dim
am831 cc85576
log and capture interrupts
am831 2271edc
dataset prep and plotting loss
am831 e7b2195
Language ID Dataloader (#2)
zrthxn c2488e9
Model Fixes (#3)
zrthxn 9eb5876
get embed_dim dynamically (#4)
am831 7244db5
Model Fixes
zrthxn d2810d8
address some feedback
zrthxn 622675c
save plot as pkl
am831 8980655
switch model to train mode
am831 13bcdea
Code cleanup
zrthxn f86a201
BCE loss
zrthxn e0bffc1
Remove Label smoothing
zrthxn e269cef
change model to increase train loss
am831 27c8c5d
lid classification head training script
mavlyutovr 671746e
fixes
mavlyutovr 71d248c
Merge pull request #8 from zrthxn/lid/ruslan-fixes
am831 6485277
Merge branch 'language_id' of https://github.com/am831/seamless_commu…
am831 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
157 changes: 157 additions & 0 deletions
157
src/seamless_communication/cli/m4t/classification_head/dataset.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # MIT_LICENSE file in the root directory of this source tree. | ||
|
|
||
| import argparse | ||
| import dataclasses | ||
| import json | ||
| import logging | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
||
| from datasets import load_dataset | ||
| from seamless_communication.datasets.huggingface import ( | ||
| SpeechTokenizer, | ||
| ) | ||
| from seamless_communication.models.unit_extractor import UnitExtractor | ||
|
|
||
| logging.basicConfig( | ||
| level=logging.INFO, | ||
| format="%(asctime)s %(levelname)s -- %(name)s: %(message)s", | ||
| ) | ||
|
|
||
| logger = logging.getLogger("dataset") | ||
|
|
||
| UNITY_TO_COMMON_VOICE_LANG_MAPPING = { | ||
| "eng": "en", | ||
| "ita": "it", | ||
| "afr": "af", | ||
| "asm": "as", | ||
| "bel": "be", | ||
| "bul": "bg", | ||
| "ben": "bn", | ||
| "cat": "ca", | ||
| "ces": "cs", | ||
| "dan": "da", | ||
| "deu": "de", | ||
| "ell": "el", | ||
| "fin": "fi", | ||
| "fra": "fr", | ||
| "glg": "gl", | ||
| "heb": "he", | ||
| "hin": "hi", | ||
| "hrv": "hr", | ||
| "hun": "hu", | ||
| "ind": "id", | ||
| "ibo": "ig", | ||
| "isl": "is", | ||
| "jpn": "ja", | ||
| "jav": "jv", | ||
| "kaz": "kk", | ||
| "kan": "kn", | ||
| "kir": "ky", | ||
| "kor": "ko", | ||
| "lit": "lt", | ||
| "mkd": "mk", | ||
| "mlt": "mt", | ||
| "mya": "my", | ||
| "nld": "nl", | ||
| "pan": "pa", | ||
| "pol": "pl", | ||
| "ron": "ro", | ||
| "rus": "ru", | ||
| "snd": "sd", | ||
| "slk": "sk", | ||
| "spa": "es", | ||
| "srp": "sr", | ||
| "swh": "sw", | ||
| "tam": "ta", | ||
| "tel": "te", | ||
| "tha": "th", | ||
| "tur": "tr", | ||
| "ukr": "uk", | ||
| "urd": "ur", | ||
| "uzn": "uz", | ||
| "vie": "vi", | ||
| "yor": "yo", | ||
| "zul": "zu" | ||
| } | ||
|
|
||
| def _check_lang_code_mapping(lang: str) -> None: | ||
| if lang not in UNITY_TO_COMMON_VOICE_LANG_MAPPING: | ||
| raise ValueError( | ||
| f"No language code mapping for {lang}(M4T)->??(CV). " | ||
| "Please expand `UNITY_TO_COMMON_VOICE_LANG_MAPPING`" | ||
| ) | ||
|
|
||
| class UnitSpeechTokenizer(SpeechTokenizer): | ||
| MODEL_NAME = "xlsr2_1b_v2" | ||
| KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy" | ||
| OUTPUT_LAYER_IDX = 34 | ||
|
|
||
| def __init__(self, device: torch.device): | ||
| super().__init__() | ||
| self.device = device | ||
| self.unit_extractor = UnitExtractor( | ||
| model_name_or_card=self.MODEL_NAME, | ||
| kmeans_uri=self.KMEANS_MODEL_URI, | ||
| device=self.device, | ||
| ) | ||
|
|
||
| def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: | ||
| return self.unit_extractor.predict( | ||
| wav.to(self.device), | ||
| out_layer_idx=self.OUTPUT_LAYER_IDX, | ||
| sample_rate=sample_rate, | ||
| ) | ||
|
|
||
| def download_common_voice(lang: str, split: str, save_directory: str): | ||
| _check_lang_code_mapping(lang) | ||
| dataset = load_dataset('mozilla-foundation/common_voice_17_0', lang, split=split) | ||
| manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json") | ||
| with open(manifest_path, "w") as fp_out: | ||
| for idx, sample in enumerate(dataset, start=1): | ||
| sample['lang'] = lang | ||
| sample['waveform'] = None # already extracted units | ||
| fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n") | ||
| logger.info(f"Saved {idx} samples for split={split} to {manifest_path}") | ||
| logger.info(f"Manifest saved to: {manifest_path}") | ||
|
|
||
| def init_parser() -> argparse.ArgumentParser: | ||
| parser = argparse.ArgumentParser( | ||
| description=( | ||
| "Helper script to download training/evaluation dataset (Common Voice)," | ||
| "extract units from target audio and save the dataset as a manifest " | ||
| "consumable by `finetune.py`." | ||
| ) | ||
| ) | ||
| parser.add_argument( | ||
| "--lang", | ||
| type=str, | ||
| required=True, | ||
| help="Language of the dataset", | ||
| ) | ||
| parser.add_argument( | ||
| "--split", | ||
| type=str, | ||
| required=True, | ||
| help="Dataset split/shard to download (`train`, `validation`, `test`)", | ||
| ) | ||
| parser.add_argument( | ||
| "--save_dir", | ||
| type=Path, | ||
| required=True, | ||
| help="Directory where the datasets will be stored with HuggingFace datasets cache files", | ||
| ) | ||
| return parser | ||
|
|
||
| def main() -> None: | ||
| args = init_parser().parse_args() | ||
| download_common_voice(args.lang, args.split, args.save_dir) | ||
|
|
||
| if __name__ == "__main__": | ||
| main() |
22 changes: 22 additions & 0 deletions
22
src/seamless_communication/cli/m4t/classification_head/model.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| from torch import nn | ||
|
|
||
| class ClassificationHead(nn.Module): | ||
| def __init__(self, num_languages, num_layers): | ||
| super(ClassificationHead, self).__init__() | ||
| self.num_languages = num_languages | ||
| self.num_layers = num_layers | ||
| self.hidden_dim = None | ||
| self.input_dim = None | ||
| self.layers = None | ||
|
|
||
| def forward(self, x): | ||
| if self.layers is None: | ||
| self.input_dim = x.size(-1) | ||
| self.hidden_dim = self.input_dim | ||
| self.layers = nn.Sequential( | ||
| nn.Linear(self.input_dim, self.hidden_dim), | ||
| nn.ReLU(), | ||
| nn.Linear(self.hidden_dim, self.num_languages) | ||
| ) | ||
| return self.layers(x) | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.