update textgen lstm
textgen_lstm.py (161 changed lines)
@@ -2,6 +2,7 @@
 import re, random
 import numpy as np
+import matplotlib.pyplot as plt
 import torch
 import torch.nn.functional as F
 from textgen import textgen
@@ -16,9 +17,11 @@ install()
 class Model(nn.ModuleList):
 
-    def __init__(self, args):
+    def __init__(self, args, device):
         super(Model, self).__init__()
 
+        self.device = device
+
         self.batch_size = args["batch_size"]
         self.hidden_dim = args["hidden_dim"]
         self.input_size = args["vocab_size"]
@@ -26,7 +29,7 @@ class Model(nn.ModuleList):
         self.sequence_len = args["window"]
 
         # Dropout
-        self.dropout = nn.Dropout(0.25)
+        self.dropout = nn.Dropout(0.25)  # Don't need to set device for the layers as we transfer the whole model later
 
         # Embedding layer
         self.embedding = nn.Embedding(self.input_size, self.hidden_dim, padding_idx=0)
@@ -47,16 +50,16 @@ class Model(nn.ModuleList):
         # Bi-LSTM
         # hs = [batch_size x hidden_size]
         # cs = [batch_size x hidden_size]
-        hs_forward = torch.zeros(x.size(0), self.hidden_dim)
-        cs_forward = torch.zeros(x.size(0), self.hidden_dim)
-        hs_backward = torch.zeros(x.size(0), self.hidden_dim)
-        cs_backward = torch.zeros(x.size(0), self.hidden_dim)
+        hs_forward = torch.zeros(x.size(0), self.hidden_dim).to(self.device)  # Need to specify device here as this is not part of the model directly
+        cs_forward = torch.zeros(x.size(0), self.hidden_dim).to(self.device)
+        hs_backward = torch.zeros(x.size(0), self.hidden_dim).to(self.device)
+        cs_backward = torch.zeros(x.size(0), self.hidden_dim).to(self.device)
 
         # LSTM
         # hs = [batch_size x (hidden_size * 2)]
         # cs = [batch_size x (hidden_size * 2)]
-        hs_lstm = torch.zeros(x.size(0), self.hidden_dim * 2)
-        cs_lstm = torch.zeros(x.size(0), self.hidden_dim * 2)
+        hs_lstm = torch.zeros(x.size(0), self.hidden_dim * 2).to(self.device)
+        cs_lstm = torch.zeros(x.size(0), self.hidden_dim * 2).to(self.device)
 
         # Weights initialization
         torch.nn.init.kaiming_normal_(hs_forward)
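The `.to(self.device)` calls above are needed because `model.to(device)` only moves a module's registered parameters and buffers; tensors created ad hoc inside `forward()` stay on the CPU unless moved explicitly. A minimal sketch of that distinction (not part of the commit; `layer` and `state` are hypothetical names):

    import torch
    import torch.nn as nn

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    layer = nn.Linear(4, 4).to(device)  # registered weights follow the module
    state = torch.zeros(1, 4)           # allocated "inside forward": stays on the CPU
    state = state.to(device)            # the explicit move this commit applies to hs/cs
    out = layer(state)                  # both operands now live on the same device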
@@ -104,16 +107,43 @@ class LSTMTextGenerator(textgen):
         self.windowsize = windowsize  # We slide a window over the character sequence and look at the next letter,
                                       # similar to the Markov chain order
 
 
     def init(self, filename):
+        self.filename = filename
 
         # Use this to generate one hot vector and filter characters
         self.letters = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
-                        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " "]
+                        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "ä", "ö", "ü", ".", " "]
 
         with open(f"./textfiles/{filename}.txt", "r") as file:
             lines = [line.lower() for line in file.readlines()]  # lowercase list
             text = " ".join(lines)  # single string
             self.charbase = [char for char in text if char in self.letters]  # list of characters
 
+        # Select device
+        if torch.cuda.is_available():
+            dev = "cuda:0"
+            print("Selected GPU for LSTM")
+        else:
+            dev = "cpu"
+            print("Selected CPU for LSTM")
+        self.device = torch.device(dev)
+
+        # Init model
+        self.args = {
+            "window": self.windowsize,
+            "hidden_dim": 128,
+            "vocab_size": len(self.letters),
+            "batch_size": 128,
+            "learning_rate": 0.0005,
+            "num_epochs": 100
+        }
+        self.model = Model(self.args, self.device)
+        self.model.to(self.device)  # All model layers need to use the correct tensors (cpu/gpu)
+
+        # Needed for both training and generation
+        self.__generate_char_sequences()
+
     # Helper shit
 
     def __char_to_idx(self, char):
@@ -148,72 +178,50 @@ class LSTMTextGenerator(textgen):
 
     # Interface shit
 
+    # TODO: Also save/load generated prefixes
     def load(self):
-        print(f"Loaded LSTM model with {len(self.charbase)} characters from file.")
+        print(f"Loading \"{self.filename}\" LSTM model with {len(self.charbase)} characters from file.")
 
-        # TODO: Deduplicate args
-        args = {
-            "window": self.windowsize,
-            "hidden_dim": 128,
-            "vocab_size": len(self.letters),
-            "batch_size": 128,
-            "learning_rate": 0.001,
-            "num_epochs": 50
-        }
-
-        self.model = Model(args)
-
-        # model.load_state_dict(torch.load('weights/kommunistisches_manifest_lstm_model.pt'))
+        self.model.load_state_dict(torch.load(f"weights/{self.filename}_lstm_model.pt"))
 
     def train(self):
-        print(f"Training LSTM model with {len(self.charbase)} characters.")
+        print(f"Training \"{self.filename}\" LSTM model with {len(self.charbase)} characters.")
 
-        args = {
-            "window": self.windowsize,
-            "hidden_dim": 128,
-            "vocab_size": len(self.letters),
-            "batch_size": 128,
-            "learning_rate": 0.001,
-            "num_epochs": 50
-        }
-
-        self.__generate_char_sequences()
-
-        # Model initialization
-        self.model = Model(args)
-
-        # Optimizer initialization
-        optimizer = optim.RMSprop(self.model.parameters(), lr=args["learning_rate"])
+        # Optimizer initialization, RMSprop for RNN
+        optimizer = optim.RMSprop(self.model.parameters(), lr=self.args["learning_rate"])
 
         # Defining number of batches
-        num_batches = int(len(self.prefixes) / args["batch_size"])
+        num_batches = int(len(self.prefixes) / self.args["batch_size"])
 
         # Set model in training mode
         self.model.train()
 
+        losses = []
+
         # Training pahse
-        for epoch in range(args["num_epochs"]):
+        for epoch in range(self.args["num_epochs"]):
 
             # Mini batches
             for i in range(num_batches):
 
                 # Batch definition
                 try:
-                    x_batch = self.prefixes[i * args["batch_size"] : (i + 1) * args["batch_size"]]
-                    y_batch = self.suffixes[i * args["batch_size"] : (i + 1) * args["batch_size"]]
+                    x_batch = self.prefixes[i * self.args["batch_size"]:(i + 1) * self.args["batch_size"]]
+                    y_batch = self.suffixes[i * self.args["batch_size"]:(i + 1) * self.args["batch_size"]]
                 except:
-                    x_batch = self.prefixes[i * args["batch_size"] :]
-                    y_batch = self.suffixes[i * args["batch_size"] :]
+                    x_batch = self.prefixes[i * self.args["batch_size"]:]
+                    y_batch = self.suffixes[i * self.args["batch_size"]:]
 
                 # Convert numpy array into torch tensors
-                x = torch.from_numpy(x_batch).type(torch.long)
-                y = torch.from_numpy(y_batch).type(torch.long)
+                x = torch.from_numpy(x_batch).type(torch.long).to(self.device)
+                y = torch.from_numpy(y_batch).type(torch.long).to(self.device)
 
                 # Feed the model
                 y_pred = self.model(x)
 
                 # Loss calculation
-                loss = F.cross_entropy(y_pred, y.squeeze())
+                loss = F.cross_entropy(y_pred, y.squeeze()).to(self.device)
+                losses += [loss.item()]
 
                 # Clean gradients
                 optimizer.zero_grad()
@@ -226,35 +234,44 @@ class LSTMTextGenerator(textgen):
 
             print("Epoch: %d , loss: %.5f " % (epoch, loss.item()))
 
-        torch.save(self.model.state_dict(), 'weights/kommunistisches_manifest_lstm_model.pt')
+        torch.save(self.model.state_dict(), f"weights/{self.filename}_lstm_model.pt")
+        print(f"Saved \"{self.filename}\" LSTM model to file")
 
+        plt.plot(np.arange(0, len(losses)), losses)
+        plt.title(self.filename)
+        plt.show()
 
 
     def generate_sentence(self):
+        # Randomly is selected the index from the set of sequences
+        start = np.random.randint(0, len(self.prefixes)-1)
+
+        # Convert back to string to match complete_sentence
+        pattern = "".join([self.__idx_to_char(char) for char in self.prefixes[start]])  # random sequence from the training text
+
+        return self.complete_sentence(pattern)
+
+    def complete_sentence(self, prefix):
+        print("Prefix:", prefix)
+
+        # Convert to indexes np.array
+        pattern = np.array([self.__char_to_idx(char) for char in prefix])
 
         # Set the model in evalulation mode
         self.model.eval()
 
         # Define the softmax function
-        softmax = nn.Softmax(dim=1)
+        softmax = nn.Softmax(dim=1).to(self.device)
 
-        # Randomly is selected the index from the set of sequences
-        start = np.random.randint(0, len(self.prefixes)-1)
-
-        # The pattern is defined given the random idx
-        pattern = self.prefixes[start]
-
-        # By making use of the dictionaries, it is printed the pattern
-        print("\nPattern: \n")
-        print(''.join([self.__idx_to_char(value) for value in pattern]), "\"")
-
         # In full_prediction we will save the complete prediction
         full_prediction = pattern.copy()
 
-        # the prediction starts, it is going to be predicted a given
-        # number of characters
-        for _ in range(250):
+        print("Generating sentence...")
+
+        # Predic the next characters one by one, append chars to the starting pattern until . is reached, max 500 iterations
+        for _ in range(500):
             # the numpy patterns is transformed into a tesor-type and reshaped
-            pattern = torch.from_numpy(pattern).type(torch.long)
+            pattern = torch.from_numpy(pattern).type(torch.long).to(self.device)
             pattern = pattern.view(1,-1)
 
             # make a prediction given the pattern
@@ -263,12 +280,12 @@ class LSTMTextGenerator(textgen):
             prediction = softmax(prediction)
 
             # the prediction tensor is transformed into a numpy array
-            prediction = prediction.squeeze().detach().numpy()
+            prediction = prediction.squeeze().detach().cpu().numpy()
             # it is taken the idx with the highest probability
            arg_max = np.argmax(prediction)
 
             # the current pattern tensor is transformed into numpy array
-            pattern = pattern.squeeze().detach().numpy()
+            pattern = pattern.squeeze().detach().cpu().numpy()
             # the window is sliced 1 character to the right
             pattern = pattern[1:]
             # the new pattern is composed by the "old" pattern + the predicted character
@@ -277,8 +294,10 @@ class LSTMTextGenerator(textgen):
             # the full prediction is saved
             full_prediction = np.append(full_prediction, arg_max)
 
-            print("prediction: \n")
-            print(''.join([self.__idx_to_char(value) for value in full_prediction]), "\"")
+            # Stop on . character
+            if self.__idx_to_char(arg_max) == ".":
+                break
 
-    def complete_sentence(self, prefix):
-        pass
+        full_prediction = "".join([self.__idx_to_char(value) for value in full_prediction])
+        print("Generated:", full_prediction)
+        return full_prediction
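Taken together, the commit moves the model and all hand-built tensors onto one device, dedupes the hyperparameter dict into self.args, and derives save/load paths from self.filename. A rough usage sketch under those assumptions (the windowsize constructor argument and the textfiles/<name>.txt and weights/<name>_lstm_model.pt layout are inferred from the diff, not confirmed by it):

    from textgen_lstm import LSTMTextGenerator

    gen = LSTMTextGenerator(windowsize=10)  # window length, akin to a Markov chain order
    gen.init("kommunistisches_manifest")    # reads ./textfiles/kommunistisches_manifest.txt, selects CPU/GPU
    gen.train()                             # trains, saves the weights, plots the loss curve
    # gen.load()                            # alternatively, restore previously saved weights
    print(gen.generate_sentence())          # completes a random prefix until "." or 500 characters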