import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
import re
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

# Download NLTK resources
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')

# Initialize the lemmatizer and stop words list
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    # Lowercase the text
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Tokenize text
    tokens = word_tokenize(text)
    # Lemmatize and remove stop words
    tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words]
    return tokens
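
# Note: preprocess_text is a standalone NLTK cleanup helper; it is not called
# anywhere in the T5 fine-tuning pipeline below.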


# Load pre-trained T5 model and tokenizer
model = T5ForConditionalGeneration.from_pretrained('t5-small')
tokenizer = T5Tokenizer.from_pretrained('t5-small')

dataset = load_dataset("bookcorpus", split="train")  # For BooksCorpus
wiki_dataset = load_dataset("wikipedia", "20220301.en", split="train")  # For Wikipedia
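# Note: BookCorpus and Wikipedia are raw text corpora without error/correction
# pairs, and wiki_dataset is not used below. The training function assumes a
# dataset with parallel "input_texts" and "output_texts" columns, so in practice
# a dedicated grammar-correction dataset (or synthetically corrupted text) would
# be needed here.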

# Define a training function
def train_model(dataset):
    # Tokenize inputs (with the "correct:" task prefix) and targets, and attach
    # the target token ids as labels so the Trainer can compute the seq2seq loss.
    def tokenize_batch(batch):
        model_inputs = tokenizer(
            ["correct: " + text for text in batch["input_texts"]],
            padding="max_length", truncation=True, max_length=128,
        )
        targets = tokenizer(
            batch["output_texts"],
            padding="max_length", truncation=True, max_length=128,
        )
        model_inputs["labels"] = targets["input_ids"]
        return model_inputs

    tokenized_dataset = dataset.map(tokenize_batch, batched=True)

    # Define Trainer
    training_args = TrainingArguments(
        output_dir='./results',
        per_device_train_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    trainer.train()

# Train the model on the processed dataset
train_model(dataset)



def correct_grammar(text):
    input_text = "correct: " + text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = model.generate(input_ids)
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_text

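# Note: without the fine-tuning step above (or a checkpoint already trained on a
# grammar-correction task), the off-the-shelf t5-small model is unlikely to
# return a corrected sentence for the "correct:" prefix.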
# Example usage
test_sentence = "She go to the market every morning."
print("Corrected Sentence:", correct_grammar(test_sentence))