enable lstm in bot
This commit is contained in:
21
bot.py
21
bot.py
@ -9,6 +9,7 @@ from typing import Optional, Union
|
||||
|
||||
from textgen import textgen
|
||||
from textgen_markov import MarkovTextGenerator
|
||||
from textgen_lstm import LSTMTextGenerator
|
||||
|
||||
# TODO: Reenable and extend scraper
|
||||
# from models import Models
|
||||
@ -57,17 +58,27 @@ class HeidiClient(discord.Client):
|
||||
# Textgen
|
||||
self.textgen_models: dict[str, textgen] = {
|
||||
# The name must correspond to the name of the training text file
|
||||
"bibel": MarkovTextGenerator(3), # Prefix length of 3
|
||||
"kommunistisches_manifest": MarkovTextGenerator(3),
|
||||
"musk": MarkovTextGenerator(3)
|
||||
"kommunistisches_manifest": LSTMTextGenerator(10),
|
||||
# "musk": LSTMTextGenerator(10),
|
||||
# "bibel": LSTMTextGenerator(10)
|
||||
|
||||
# "bibel": MarkovTextGenerator(3), # Prefix length of 3
|
||||
# "kommunistisches_manifest": MarkovTextGenerator(3),
|
||||
# "musk": MarkovTextGenerator(3)
|
||||
}
|
||||
|
||||
for name, model in self.textgen_models.items():
|
||||
model.init(name) # Loads the textfile
|
||||
if DOCKER:
|
||||
|
||||
if os.path.exists(f"weights/{name}_lstm_model.pt"):
|
||||
model.load()
|
||||
else:
|
||||
elif not DOCKER:
|
||||
model.train()
|
||||
else:
|
||||
print("Error: Can't load model", name)
|
||||
|
||||
print("Generating test sentence for", name)
|
||||
self.textgen_models[name].generate_sentence()
|
||||
|
||||
# Synchronize commands to guilds
|
||||
async def setup_hook(self):
|
||||
|
Reference in New Issue
Block a user