Compare commits
8 Commits
9b66061ee7
...
v3.1
| Author | SHA1 | Date | |
|---|---|---|---|
| c2847de7dd | |||
| 08230eb3de | |||
| f2ddb4ab66 | |||
| 876232f674 | |||
| d7c3a7c740 | |||
| bdcd5208a7 | |||
| 79fcf0142a | |||
| 0f6cc12182 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -12,3 +12,4 @@ Pipfile.lock
|
||||
/disabled_voicelines/
|
||||
*.svg
|
||||
.vscode
|
||||
Heidi_User.conf
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
workflow: # for entire pipeline
|
||||
rules:
|
||||
- if: '$CI_COMMIT_REF_NAME == "master"' # only run on master...
|
||||
changes: # ...and when these files have changed
|
||||
- "*.py"
|
||||
- "Dockerfile"
|
||||
|
||||
docker-build:
|
||||
stage: build
|
||||
image: docker:20 # provides the docker toolset (but without an active daemon)
|
||||
services: # configure images that run during jobs linked to the image (above)
|
||||
- docker:dind # dind builds on the docker image and starts the Docker daemon (the docker image itself doesn't do that), which is needed to call docker build etc.
|
||||
before_script:
|
||||
- docker login -u $CI_REGISTRY_USER -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
|
||||
script:
|
||||
- docker pull $CI_REGISTRY_IMAGE:latest || true # latest image for cache (not failing if image is not found)
|
||||
- >
|
||||
docker build
|
||||
--pull
|
||||
--cache-from $CI_REGISTRY_IMAGE:latest
|
||||
--label "org.opencontainers.image.title=$CI_PROJECT_TITLE"
|
||||
--label "org.opencontainers.image.url=$CI_PROJECT_URL"
|
||||
--label "org.opencontainers.image.created=$CI_JOB_STARTED_AT"
|
||||
--label "org.opencontainers.image.revision=$CI_COMMIT_SHA"
|
||||
--label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME"
|
||||
--tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA
|
||||
.
|
||||
- docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA $CI_REGISTRY_IMAGE:latest
|
||||
- docker push $CI_REGISTRY_IMAGE:latest
|
||||
|
||||
docker-deploy:
|
||||
stage: deploy
|
||||
image: alpine:3.15
|
||||
needs: ["docker-build"]
|
||||
script:
|
||||
- chmod og= $ID_RSA
|
||||
- apk update && apk add openssh-client
|
||||
- ssh -i $ID_RSA -o StrictHostKeyChecking=no $SERVER_USER@$SERVER_IP "/home/christoph/$CI_PROJECT_TITLE/launch.sh"
|
||||
@ -1,2 +0,0 @@
|
||||
[ENTRANCE.SOUND]
|
||||
|
||||
72
bot.py
72
bot.py
@ -120,8 +120,7 @@ def user_entrance_sound_autocomplete(
|
||||
"""
|
||||
boards: List[str] = os.listdir(SOUNDDIR)
|
||||
all_sounds: Dict[str, List[str]] = {
|
||||
board: list(map(lambda x: x.split(".")[0], os.listdir(f"{SOUNDDIR}/{board}/")))
|
||||
for board in boards
|
||||
board: os.listdir(f"{SOUNDDIR}/{board}/") for board in boards
|
||||
} # These are all sounds, organized per board
|
||||
|
||||
# @todo Initially only suggest boards, because there are too many sounds to show them all
|
||||
@ -133,7 +132,7 @@ def user_entrance_sound_autocomplete(
|
||||
for sound in board_sounds: # Iterate over board specific sounds
|
||||
soundpath = f"{board}/{sound}"
|
||||
if soundpath.lower().startswith(current.lower()):
|
||||
completions += [Choice(name=soundpath, value=soundpath)]
|
||||
completions += [Choice(name=soundpath.split(".")[0], value=soundpath)]
|
||||
|
||||
return completions
|
||||
|
||||
@ -150,6 +149,7 @@ def user_entrance_sound_autocomplete(
|
||||
config_value="Der Wert, auf welche die Option gesetzt werden soll."
|
||||
)
|
||||
@app_commands.autocomplete(config_value=user_config_value_autocomplete)
|
||||
@enforce_channel(HEIDI_SPAM_ID)
|
||||
async def user_config(
|
||||
interaction: Interaction, config_key: str, config_value: str
|
||||
) -> None:
|
||||
@ -159,7 +159,7 @@ async def user_config(
|
||||
# Only Members can set settings
|
||||
if not isinstance(interaction.user, Member):
|
||||
print("User not a member")
|
||||
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!")
|
||||
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!", ephemeral=True)
|
||||
return
|
||||
|
||||
member: Member = interaction.user
|
||||
@ -168,7 +168,8 @@ async def user_config(
|
||||
client.write_user_config()
|
||||
|
||||
await interaction.response.send_message(
|
||||
f"Ok, ich schreibe {member.name}={config_value} in mein fettes Gehirn!"
|
||||
f"Ok, ich schreibe {member.name}={config_value} in mein fettes Gehirn!",
|
||||
ephemeral=True
|
||||
)
|
||||
|
||||
|
||||
@ -176,6 +177,7 @@ async def user_config(
|
||||
|
||||
|
||||
@client.tree.command(name="heidi", description="Heidi!")
|
||||
@enforce_channel(HEIDI_SPAM_ID)
|
||||
async def heidi_exclaim(interaction: Interaction) -> None:
|
||||
"""
|
||||
Print a random Heidi quote.
|
||||
@ -188,6 +190,10 @@ async def heidi_exclaim(interaction: Interaction) -> None:
|
||||
"Warum denn so schüchtern?",
|
||||
"Im TV ist das legal!",
|
||||
"Das Stroh ist nur fürs Shooting!",
|
||||
"Jetzt sei doch mal sexy!",
|
||||
"Stell dich nicht so an!",
|
||||
"Models müssen da halt durch!",
|
||||
"Heul doch nicht!",
|
||||
]
|
||||
await interaction.response.send_message(random.choice(messages))
|
||||
|
||||
@ -195,6 +201,7 @@ async def heidi_exclaim(interaction: Interaction) -> None:
|
||||
@client.tree.command(name="miesmuschel", description="Was denkt Heidi?")
|
||||
@app_commands.rename(question="frage")
|
||||
@app_commands.describe(question="Heidi wird es beantworten!")
|
||||
@enforce_channel(HEIDI_SPAM_ID)
|
||||
async def magic_shell(interaction: Interaction, question: str) -> None:
|
||||
"""
|
||||
Answer a yes/no question.
|
||||
@ -208,14 +215,13 @@ async def magic_shell(interaction: Interaction, question: str) -> None:
|
||||
"Klaro Karo",
|
||||
"Offensichtlich Sherlock",
|
||||
"Tom sagt Ja",
|
||||
|
||||
"Nein!",
|
||||
"Nö.",
|
||||
"Nä.",
|
||||
"Niemals!",
|
||||
"Nur über meine Leiche du Hurensohn!",
|
||||
"In deinen Träumen.",
|
||||
"Tom sagt Nein"
|
||||
"Tom sagt Nein",
|
||||
]
|
||||
question = question.strip()
|
||||
question_mark = "" if question[-1] == "?" else "?"
|
||||
@ -230,6 +236,7 @@ async def magic_shell(interaction: Interaction, question: str) -> None:
|
||||
@app_commands.describe(option_a="Ist es vielleicht dies?")
|
||||
@app_commands.rename(option_b="oder")
|
||||
@app_commands.describe(option_b="Oder doch eher das?")
|
||||
@enforce_channel(HEIDI_SPAM_ID)
|
||||
async def choose(interaction: Interaction, option_a: str, option_b: str) -> None:
|
||||
"""
|
||||
Select an answer from two options.
|
||||
@ -265,12 +272,10 @@ async def sound_autocomplete(
|
||||
Suggest a sound from an already selected board.
|
||||
"""
|
||||
board: str = interaction.namespace.board
|
||||
sounds: List[str] = list(
|
||||
map(lambda x: x.split(".")[0], os.listdir(f"{SOUNDDIR}/{board}/"))
|
||||
)
|
||||
sounds: List[str] = os.listdir(f"{SOUNDDIR}/{board}/")
|
||||
|
||||
return [
|
||||
Choice(name=sound, value=sound)
|
||||
Choice(name=sound.split(".")[0], value=sound)
|
||||
for sound in sounds
|
||||
if sound.lower().startswith(current.lower())
|
||||
]
|
||||
@ -282,6 +287,7 @@ async def sound_autocomplete(
|
||||
@app_commands.describe(sound="Was soll Heidi sagen?")
|
||||
@app_commands.autocomplete(board=board_autocomplete)
|
||||
@app_commands.autocomplete(sound=sound_autocomplete)
|
||||
@enforce_channel(HEIDI_SPAM_ID)
|
||||
async def say_voiceline(interaction: Interaction, board: str, sound: str) -> None:
|
||||
"""
|
||||
Play a voiceline in the calling member's current voice channel.
|
||||
@ -289,7 +295,7 @@ async def say_voiceline(interaction: Interaction, board: str, sound: str) -> Non
|
||||
# Only Members can access voice channels
|
||||
if not isinstance(interaction.user, Member):
|
||||
print("User not a member")
|
||||
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!")
|
||||
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!", ephemeral=True)
|
||||
return
|
||||
|
||||
member: Member = interaction.user
|
||||
@ -297,6 +303,43 @@ async def say_voiceline(interaction: Interaction, board: str, sound: str) -> Non
|
||||
await play_voice_line_for_member(interaction, member, board, sound)
|
||||
|
||||
|
||||
class InstantButton(discord.ui.Button):
    """
    A red button that plays one sound of a soundboard when it is pressed.
    """

    def __init__(self, label: str, board: str, sound: str):
        """
        label - The text shown on the button
        board - The board argument handed to play_voice_line_for_member on a press
        sound - The sound argument handed to play_voice_line_for_member on a press
        """
        super().__init__(label=label, style=discord.ButtonStyle.red)

        self.sound = sound
        self.board = board

    async def callback(self, interaction: Interaction):
        """
        Handle a press of the button.
        """
        presser = interaction.user

        # Only a Member is passed on to play_voice_line_for_member, everybody else gets a private refusal
        if isinstance(presser, Member):
            await play_voice_line_for_member(interaction, presser, self.board, self.sound)
            return

        await interaction.response.send_message("Heidi mag keine discord.User, nur discord.Member!", ephemeral=True)
|
||||
|
||||
|
||||
class InstantButtons(discord.ui.View):
    """
    A view holding one InstantButton per sound file of a soundboard.
    """

    # A discord.ui.View accepts at most 25 children, add_item raises ValueError beyond that
    MAX_BUTTONS = 25

    def __init__(self, board: str, timeout=None):
        """
        board - Name of the directory below SOUNDDIR whose files become buttons
        timeout - Passed on to discord.ui.View unchanged

        Raises OSError (from os.listdir) if the board directory can't be listed.
        """
        super().__init__(timeout=timeout)

        # Sorted, because os.listdir returns the entries in arbitrary order
        sounds = sorted(os.listdir(f"{SOUNDDIR}/{board}"))

        if len(sounds) > self.MAX_BUTTONS:
            print(f"Board {board} has {len(sounds)} sounds, only showing the first {self.MAX_BUTTONS}")

        for sound in sounds[: self.MAX_BUTTONS]:
            # The label is the filename up to the first dot, the button itself keeps the full filename
            self.add_item(InstantButton(sound.split(".")[0], board, sound))
|
||||
|
||||
|
||||
@client.tree.command(
    name="instantbuttons", description="Heidi malt Knöpfe für Sounds in den Chat."
)
@app_commands.describe(board="Welches Soundboard soll knöpfe bekommen?")
@app_commands.autocomplete(board=board_autocomplete)
@enforce_channel(HEIDI_SPAM_ID)
async def soundboard_buttons(interaction: Interaction, board: str) -> None:
    """
    Answer with a message that carries one button per sound of the selected board.

    board - Name of the soundboard, used as directory name by InstantButtons
    """
    try:
        buttons = InstantButtons(board)
    except OSError:
        # InstantButtons lists the board directory; a board that doesn't exist (or can't be read)
        # must still produce a response, otherwise the interaction is never answered
        print(f"Error: Invalid board {SOUNDDIR}/{board}!")
        await interaction.response.send_message(
            f'Heidi sagt: "{board}" kanninich finden bruder', ephemeral=True
        )
        return

    await interaction.response.send_message(f"Soundboard: {board.capitalize()}", view=buttons)
|
||||
|
||||
|
||||
# Contextmenu ------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@ -313,7 +356,7 @@ async def insult(
|
||||
|
||||
if not member.dm_channel:
|
||||
print("Error creating DMChannel!")
|
||||
await interaction.response.send_message("Heidi sagt: Gib mal DM Nummer süße*r!")
|
||||
await interaction.response.send_message("Heidi sagt: Gib mal DM Nummer süße*r!", ephemeral=True)
|
||||
return
|
||||
|
||||
insults = [
|
||||
@ -334,7 +377,8 @@ async def insult(
|
||||
|
||||
await member.dm_channel.send(random.choice(insults))
|
||||
await interaction.response.send_message(
|
||||
"Anzeige ist raus!"
|
||||
"Anzeige ist raus!",
|
||||
ephemeral=True
|
||||
) # with ephemeral = True only the caller can see the answer
|
||||
|
||||
|
||||
|
||||
@ -115,9 +115,20 @@ class HeidiClient(discord.Client):
|
||||
after: VoiceState,
|
||||
) -> None:
|
||||
"""
|
||||
Play a sound when a member joins a voice channel.
|
||||
Play a sound when a member joins a voice channel (and another member is present).
|
||||
This function is set in on_voice_state_triggers and triggered by the on_voice_state_update event.
|
||||
"""
|
||||
|
||||
# Don't play anything when no other users are present
|
||||
if (
|
||||
member is not None
|
||||
and member.voice is not None
|
||||
and member.voice.channel is not None
|
||||
and len(member.voice.channel.members) <= 1
|
||||
):
|
||||
print("Not playing entrance sound, as no other members are present")
|
||||
return
|
||||
|
||||
soundpath: Union[str, None] = self.user_config["ENTRANCE.SOUND"].get(
|
||||
member.name, None
|
||||
)
|
||||
|
||||
@ -24,3 +24,6 @@ SOUNDDIR: str = "/sounds" if DOCKER else "./heidi-sounds"
|
||||
# IDs of the servers Heidi is used on
|
||||
LINUS_GUILD = discord.Object(id=431154792308408340)
|
||||
TEST_GUILD = discord.Object(id=821511861178204161)
|
||||
|
||||
# Channel IDs
|
||||
HEIDI_SPAM_ID = 822223476101742682
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Union
|
||||
|
||||
import discord
|
||||
@ -8,6 +9,40 @@ from heidi_constants import *
|
||||
|
||||
print("Debug: Importing heidi_helpers.py")
|
||||
|
||||
|
||||
# Checks -----------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# 1. @enforce_channel(ID) is added to a function, which evaluates to decorate with the channel_id in its closure
# 2. The function is passed to decorate(function), which returns wrapped
# 3. wrapped takes the place of the function and checks the channel on every call before calling the function
def enforce_channel(channel_id):
    """
    Only run a function if called from the correct channel.

    channel_id - The value interaction.channel_id has to equal for the decorated coroutine to run

    Returns a decorator for coroutine functions that receive the interaction as first positional argument.
    """

    def decorate(function):

        @functools.wraps(function)
        async def wrapped(*args, **kwargs):
            """
            Sends an interaction response if the interaction is not triggered from the heidi_spam channel.
            """
            interaction: Interaction = args[0]

            # Do not call the decorated function if the channel_id doesn't match
            if interaction.channel_id != channel_id:
                await interaction.response.send_message("Heidi sagt: Geh in heidi_spam du dulli", ephemeral=True)
                return None

            # Hand the result of the decorated function back to the caller instead of dropping it
            return await function(*args, **kwargs)

        return wrapped

    return decorate
|
||||
|
||||
|
||||
# Sounds -----------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# @todo Normalize volume when playing
|
||||
async def play_voice_line(
|
||||
interaction: Union[Interaction, None],
|
||||
@ -19,20 +54,21 @@ async def play_voice_line(
|
||||
Play a voice line in the specified channel.
|
||||
"""
|
||||
try:
|
||||
open(f"{SOUNDDIR}/{board}/{sound}.mkv")
|
||||
open(f"{SOUNDDIR}/{board}/{sound}")
|
||||
except IOError:
|
||||
print("Error: Invalid soundfile!")
|
||||
print(f"Error: Invalid soundfile {SOUNDDIR}/{board}/{sound}!")
|
||||
if interaction is not None:
|
||||
await interaction.response.send_message(
|
||||
f'Heidi sagt: "{board}/{sound}" kanninich finden bruder'
|
||||
f'Heidi sagt: "{board}/{sound}" kanninich finden bruder',
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
if interaction is not None:
|
||||
await interaction.response.send_message(f'Heidi sagt: "{board}/{sound}"')
|
||||
await interaction.response.send_message(f'Heidi sagt: "{board}/{sound}"', ephemeral=True)
|
||||
|
||||
audio_source = discord.FFmpegPCMAudio(
|
||||
f"{SOUNDDIR}/{board}/{sound}.mkv"
|
||||
f"{SOUNDDIR}/{board}/{sound}"
|
||||
) # only works from docker
|
||||
voice_client = await voice_channel.connect()
|
||||
voice_client.play(audio_source)
|
||||
@ -61,7 +97,7 @@ async def play_voice_line_for_member(
|
||||
):
|
||||
print("User not in (valid) voice channel!")
|
||||
if interaction is not None:
|
||||
await interaction.response.send_message("Heidi sagt: Komm in den Channel!")
|
||||
await interaction.response.send_message("Heidi sagt: Komm in den Channel!", ephemeral=True)
|
||||
return
|
||||
|
||||
voice_channel: VoiceChannel = member.voice.channel
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
#!/bin/sh

# Redeploy HeidiBot: update the checkout, pull the current image and replace the running container.

# Stop if the checkout is missing, otherwise git pull would run in whatever directory the caller is in
cd /home/christoph/HeidiBot || exit 1
git pull

docker pull registry.gitlab.com/churl/heidibot
# Remove the old container first so the name "heidibot" can be reused
docker container rm -f heidibot
docker run -d --env-file /home/christoph/HeidiBot/.env --mount src=/home/christoph/HeidiBot/voicelines,target=/sounds,type=bind --name heidibot registry.gitlab.com/churl/heidibot
# Remove dangling images so they don't pile up on the server
docker image prune -f
|
||||
33
models.py
33
models.py
@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import requests
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
class Models:
    """
    Scrapes the candidate overview page and gives access to the candidates by name.

    girls_in, girls_out and girls map the lowercased title of a candidate link to the link element,
    for the links with class "candidate-in", "candidate-out" and both respectively.
    """

    URL = "https://www.prosieben.de/tv/germanys-next-topmodel/models"

    # Seconds to wait for the server, without a timeout requests can block forever
    TIMEOUT = 10

    def __init__(self):
        """
        Download and parse the candidate page.

        Raises the requests exceptions of requests.get, including a timeout after TIMEOUT seconds.
        """
        html_girls = requests.get(self.URL, timeout=self.TIMEOUT)
        soup_girls = BeautifulSoup(html_girls.text, "html.parser")

        girls_in = soup_girls.find_all("a", class_="candidate-in")
        girls_out = soup_girls.find_all("a", class_="candidate-out")

        self.girls_in = {girl.get("title").lower(): girl for girl in girls_in}
        self.girls_out = {girl.get("title").lower(): girl for girl in girls_out}

        self.girls = {**self.girls_in, **self.girls_out}

    def get_in_names(self):
        """
        Return the (lowercase) names of the candidates marked "candidate-in".
        """
        return self.girls_in.keys()

    def get_out_names(self):
        """
        Return the (lowercase) names of the candidates marked "candidate-out".
        """
        return self.girls_out.keys()

    def get_image(self, name):
        """
        Return the image URL of a candidate with the size replaced by 562x996.

        name - Name of the candidate, the case is ignored

        Raises KeyError if the name is unknown. A link without teaser figure or without url(...);
        in its style is not handled and fails with AttributeError/TypeError.
        """
        style = self.girls[name.lower()].find("figure", class_="teaser-img").get("style")
        url = re.search(r"url\(.*\);", style).group()

        # Cut off the leading "url(" and the last 9 characters, then append the new size
        # NOTE(review): presumably the 9 characters are a size like "300x533" plus ");" - confirm against the page
        return url[4:-9] + "562x996"  # increase resolution
|
||||
@ -3,12 +3,3 @@ rich
|
||||
discord.py # maintained again
|
||||
pynacl # voice support
|
||||
python-dotenv # discord guild secrets
|
||||
|
||||
# Webscraping
|
||||
# requests
|
||||
# beautifulsoup4
|
||||
|
||||
# Textgeneration
|
||||
# torch
|
||||
# numpy
|
||||
# nltk
|
||||
|
||||
44
textgen.py
44
textgen.py
@ -1,44 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from rich.traceback import install
|
||||
install()
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# In Python it is generally not needed to use abstract classes, but I wanted to do it safely
|
||||
|
||||
class textgen(ABC):
    """
    Interface of a text generator.

    All methods are abstract, so only subclasses overriding every one of them can be instantiated.
    """

    # Raised with this message when an implementation calls up into the abstract method
    _ABSTRACT_ERROR = "Can't use abstract class"

    @abstractmethod
    def init(self, filename):
        """
        Prepare the generator for a training text.

        filename - Selects the training text
        """
        raise NotImplementedError(self._ABSTRACT_ERROR)

    @abstractmethod
    def load(self):
        """
        Restore an already trained model instead of training it.
        """
        raise NotImplementedError(self._ABSTRACT_ERROR)

    @abstractmethod
    def train(self):
        """
        Train the model on the text selected with init().
        """
        raise NotImplementedError(self._ABSTRACT_ERROR)

    @abstractmethod
    def generate_sentence(self):
        """
        Generate words/characters until a . is generated.
        """
        raise NotImplementedError(self._ABSTRACT_ERROR)

    @abstractmethod
    def complete_sentence(self, prefix):
        """
        Continue a sentence.

        prefix - The beginning of the sentence
        """
        raise NotImplementedError(self._ABSTRACT_ERROR)
|
||||
303
textgen_lstm.py
303
textgen_lstm.py
@ -1,303 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import re, random
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from textgen import textgen
|
||||
from torch import nn, optim
|
||||
|
||||
from rich.traceback import install
|
||||
install()
|
||||
|
||||
# Model =======================================================================================
|
||||
# https://towardsdatascience.com/text-generation-with-bi-lstm-in-pytorch-5fda6e7cc22c
|
||||
# Embedding -> Bi-LSTM -> LSTM -> Linear
|
||||
|
||||
class Model(nn.ModuleList):
    """
    Embedding -> Bi-LSTM -> LSTM -> Linear

    The recurrent layers are LSTMCells that are unrolled over the sequence by hand in forward().
    """

    def __init__(self, args, device):
        """
        args - dict, the keys "batch_size", "hidden_dim", "vocab_size" and "window" are read
        device - Device the initial hidden and cell states are created on in forward()
        """
        super(Model, self).__init__()

        self.device = device

        self.batch_size = args["batch_size"]
        self.hidden_dim = args["hidden_dim"]
        self.input_size = args["vocab_size"]
        self.num_classes = args["vocab_size"]
        self.sequence_len = args["window"]

        # Dropout (created here, but not applied anywhere in forward())
        self.dropout = nn.Dropout(0.25)  # Don't need to set device for the layers as we transfer the whole model later

        # Embedding layer
        # NOTE(review): padding_idx=0 keeps the embedding of index 0 at zero without gradient - confirm that no real symbol uses index 0
        self.embedding = nn.Embedding(self.input_size, self.hidden_dim, padding_idx=0)

        # Bi-LSTM
        # Forward and backward
        self.lstm_cell_forward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        self.lstm_cell_backward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)

        # LSTM layer
        self.lstm_cell = nn.LSTMCell(self.hidden_dim * 2, self.hidden_dim * 2)

        # Linear layer
        self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)

    def forward(self, x):
        """
        x - Tensor of indices, [batch_size x sequence_len]; the length is taken from x, not from the window

        Returns the output of the linear layer, [batch_size x vocab_size].
        """
        batch_size, sequence_len = x.size(0), x.size(1)

        def initial_state(width):
            # Need to specify device here as this is not part of the model directly
            state = torch.zeros(batch_size, width, device=self.device)
            # The states start with random values (also in eval mode), like before
            torch.nn.init.kaiming_normal_(state)
            return state

        # Bi-LSTM
        # hs = [batch_size x hidden_size]
        # cs = [batch_size x hidden_size]
        hs_forward = initial_state(self.hidden_dim)
        cs_forward = initial_state(self.hidden_dim)
        hs_backward = initial_state(self.hidden_dim)
        cs_backward = initial_state(self.hidden_dim)

        # LSTM
        # hs = [batch_size x (hidden_size * 2)]
        # cs = [batch_size x (hidden_size * 2)]
        hs_lstm = initial_state(self.hidden_dim * 2)
        cs_lstm = initial_state(self.hidden_dim * 2)

        # From idx to embedding, [batch_size x sequence_len x hidden_size]
        out = self.embedding(x)

        # Prepare the shape for LSTM Cells, [sequence_len x batch_size x hidden_size]
        # The axes have to be swapped: view() only reinterprets the memory, which mixes the
        # samples of a batch as soon as batch_size > 1 (for batch_size == 1 both are the same)
        out = out.transpose(0, 1)

        forward = []
        backward = []

        # Unfolding Bi-LSTM
        # Forward
        for i in range(sequence_len):
            hs_forward, cs_forward = self.lstm_cell_forward(out[i], (hs_forward, cs_forward))
            forward.append(hs_forward)

        # Backward
        for i in reversed(range(sequence_len)):
            hs_backward, cs_backward = self.lstm_cell_backward(out[i], (hs_backward, cs_backward))
            backward.append(hs_backward)

        # LSTM
        # NOTE(review): backward is in processing order, so forward step k is paired with the backward
        # state after k + 1 steps from the end, not with the same position - kept, confirm it is intended
        for fwd, bwd in zip(forward, backward):
            input_tensor = torch.cat((fwd, bwd), 1)
            hs_lstm, cs_lstm = self.lstm_cell(input_tensor, (hs_lstm, cs_lstm))

        # Last hidden state is passed through a linear layer
        out = self.linear(hs_lstm)
        return out
|
||||
|
||||
|
||||
# =============================================================================================
|
||||
|
||||
class LSTMTextGenerator(textgen):
    """
    Character based text generator that predicts the next character with the LSTM Model.

    init() has to be called first, followed by train() or load(), afterwards sentences can be generated.
    """

    def __init__(self, windowsize):
        """
        windowsize - Number of characters the model looks at to predict the next one
        """
        # We slide a window over the character sequence and look at the next letter,
        # similar to the Markov chain order
        self.windowsize = windowsize

    def init(self, filename):
        """
        Read the training text, select the device, create the model and the training sequences.

        filename - Name of the training text, read from ./textfiles/<filename>.txt
        """
        self.filename = filename

        # The position in this list is the index given to the model, all other characters are filtered out
        # NOTE(review): "a" has index 0, which Model passes as padding_idx to its embedding - confirm this is intended
        self.letters = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
                        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "ä", "ö", "ü", ".", " "]

        allowed = set(self.letters)

        # NOTE(review): assumes the text files are UTF-8, so ä/ö/ü don't depend on the locale - confirm
        with open(f"./textfiles/{filename}.txt", "r", encoding="utf-8") as file:
            text = " ".join(line.lower() for line in file)  # single lowercase string
        self.charbase = [char for char in text if char in allowed]  # list of characters

        # Select device
        if torch.cuda.is_available():
            dev = "cuda:0"
            print("Selected GPU for LSTM")
        else:
            dev = "cpu"
            print("Selected CPU for LSTM")
        self.device = torch.device(dev)

        # Init model
        self.args = {
            "window": self.windowsize,
            "hidden_dim": 128,
            "vocab_size": len(self.letters),
            "batch_size": 128,
            "learning_rate": 0.0005,
            "num_epochs": 100,
        }
        self.model = Model(self.args, self.device)
        self.model.to(self.device)  # All model layers need to use the correct tensors (cpu/gpu)

        # Needed for both training and generation
        self.__generate_char_sequences()

    # Helpers

    def __char_to_idx(self, char):
        """
        Return the index of a character, raises ValueError for characters that are not in self.letters.
        """
        return self.letters.index(char)

    def __idx_to_char(self, idx):
        """
        Return the character that belongs to an index.
        """
        return self.letters[idx]

    def __generate_char_sequences(self):
        """
        Cut the character base into windows (self.prefixes) and the character following each window (self.suffixes).
        """
        print("Generating LSTM char sequences...")

        window = self.windowsize
        indices = [self.__char_to_idx(char) for char in self.charbase]  # convert every character only once

        # Example
        # [[21, 20, 15],
        #  [12, 12, 14]]
        prefixes = []

        # Example
        # [1, 4]
        suffixes = []

        for i in range(len(indices) - window):
            prefixes.append(indices[i:i + window])
            # The target is the character directly behind the window, generation appends its prediction
            # at exactly this position (i + window + 1 would skip one character)
            suffixes.append(indices[i + window])

        self.prefixes = np.array(prefixes)
        self.suffixes = np.array(suffixes)

        print(f"Prefixes shape: {self.prefixes.shape}")
        print(f"Suffixes shape: {self.suffixes.shape}")
        print("Completed.")

    # Interface

    # @todo Also save/load generated prefixes
    def load(self):
        """
        Load the weights saved by train() from weights/<filename>_lstm_model.pt.
        """
        print(f"Loading \"{self.filename}\" LSTM model with {len(self.charbase)} characters from file.")

        # map_location allows loading weights on the CPU that were saved from the GPU
        self.model.load_state_dict(
            torch.load(f"weights/{self.filename}_lstm_model.pt", map_location=self.device)
        )

    def train(self):
        """
        Train the model, save the weights to weights/<filename>_lstm_model.pt and plot the loss of every batch.

        Raises ValueError if there are not enough sequences for a single batch.
        """
        print(f"Training \"{self.filename}\" LSTM model with {len(self.charbase)} characters.")

        batch_size = self.args["batch_size"]

        # Defining number of batches, only full batches are used and the remaining sequences are dropped
        num_batches = len(self.prefixes) // batch_size
        if num_batches == 0:
            raise ValueError(
                f"Need at least {batch_size} sequences to train, got {len(self.prefixes)}"
            )

        # Optimizer initialization, RMSprop for RNN
        optimizer = optim.RMSprop(self.model.parameters(), lr=self.args["learning_rate"])

        # Set model in training mode
        self.model.train()

        losses = []

        # Training phase
        for epoch in range(self.args["num_epochs"]):

            # Mini batches
            for i in range(num_batches):
                batch = slice(i * batch_size, (i + 1) * batch_size)

                # Convert numpy array into torch tensors
                x = torch.from_numpy(self.prefixes[batch]).type(torch.long).to(self.device)
                y = torch.from_numpy(self.suffixes[batch]).type(torch.long).to(self.device)

                # Feed the model
                y_pred = self.model(x)

                # Loss calculation, y already is one class index per sample
                loss = F.cross_entropy(y_pred, y)
                losses.append(loss.item())

                # Clean gradients
                optimizer.zero_grad()

                # Calculate gradients
                loss.backward()

                # Update parameters
                optimizer.step()

            # Loss of the last batch of the epoch
            print("Epoch: %d , loss: %.5f " % (epoch, loss.item()))

        torch.save(self.model.state_dict(), f"weights/{self.filename}_lstm_model.pt")
        print(f"Saved \"{self.filename}\" LSTM model to file")

        plt.plot(np.arange(0, len(losses)), losses)
        plt.title(self.filename)
        plt.show()

    def generate_sentence(self):
        """
        Complete a randomly selected window of the training text.
        """
        # The upper bound of randint is exclusive, so every sequence can be selected
        start = np.random.randint(0, len(self.prefixes))

        # Convert back to string to match complete_sentence
        pattern = "".join([self.__idx_to_char(char) for char in self.prefixes[start]])  # random sequence from the training text

        return self.complete_sentence(pattern)

    def complete_sentence(self, prefix):
        """
        Append predicted characters to prefix until a . is predicted, at most 500 characters.

        prefix - String the sentence starts with; it is lowercased and unknown characters are removed,
                 like it is done with the training text

        Returns the (normalized) prefix followed by the generated characters.
        Raises ValueError if no usable character is left in prefix.
        """
        print("Prefix:", prefix)

        # Convert to indexes np.array
        known = [char for char in prefix.lower() if char in self.letters]
        if not known:
            raise ValueError("The prefix doesn't contain any known character")
        full_prediction = np.array([self.__char_to_idx(char) for char in known])

        # The model only looks at the last windowsize characters
        pattern = full_prediction[-self.windowsize:]

        # Set the model in evaluation mode
        self.model.eval()

        print("Generating sentence...")

        # Predict the next characters one by one, append chars to the starting pattern until . is reached, max 500 iterations
        with torch.no_grad():  # no gradients are needed to generate
            for _ in range(500):
                # The numpy pattern is transformed into a tensor and reshaped to a batch of one
                model_input = torch.from_numpy(pattern).type(torch.long).to(self.device).view(1, -1)

                # Make a prediction given the pattern
                prediction = self.model(model_input)

                # Take the idx with the highest score, softmax is not needed as it doesn't change the order
                arg_max = int(np.argmax(prediction.squeeze().cpu().numpy()))

                # The window is moved 1 character to the right: drop the oldest, append the predicted character
                pattern = np.append(pattern[1:], arg_max)

                # The full prediction is saved
                full_prediction = np.append(full_prediction, arg_max)

                # Stop on . character
                if self.__idx_to_char(arg_max) == ".":
                    break

        full_prediction = "".join([self.__idx_to_char(value) for value in full_prediction])
        print("Generated:", full_prediction)
        return full_prediction
|
||||
@ -1,82 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import re
|
||||
import random
|
||||
from textgen import textgen
|
||||
|
||||
from rich.traceback import install
|
||||
install()
|
||||
|
||||
# NOTE: This is word based, not character based
|
||||
# @todo Serialize and save/load model (don't train on the server)
|
||||
# @todo Maybe extract sentence beginnings and use them as starters?
|
||||
|
||||
class MarkovTextGenerator(textgen):
    """
    Word based Markov chain text generator.

    The greater the order (prefix length), the lesser the variation in generation, but the better the sentences (generally).
    If the prefix length is high there are less options to choose from, so the sentences are very close to the training text.
    """

    # Maximum number of words appended to a sentence, so a table that never leads to a "." can't loop forever
    MAX_WORDS = 500

    def __init__(self, order):  # Set order here for better interface (only needed for markov model)
        """
        order - Number of words that form a prefix
        """
        self.order = order

    def init(self, filename):  # Filename is needed for every type of model so it's part of the interface
        """
        Read the training text into self.wordbase and reset the table.

        filename - Name of the training text, read from ./textfiles/<filename>.txt
        """
        # NOTE(review): assumes the text files are UTF-8, so ä/ö/ü/ß don't depend on the locale - confirm
        with open(f"./textfiles/{filename}.txt", "r", encoding="utf-8") as file:
            # Remove all characters except a-zäöüß'.,
            self.wordbase = re.sub(r"[^a-zäöüß'.,]+", " ", file.read().lower()).split()

        self.word_table = dict()

    def load(self):
        """
        Only prints a status line, nothing is read from disk.
        """
        print(f"Loaded Markov chain of order {self.order} with {len(self.wordbase)} words from file.")

    def train(self):
        """
        Fill self.word_table: every prefix (tuple of order words) maps to the list of words that followed it.
        """
        print(f"Training Markov chain of order {self.order} with {len(self.wordbase)} words.")

        # init the frequencies
        for i in range(len(self.wordbase) - self.order):  # Every position that still has a following word
            prefix = tuple(self.wordbase[i:i + self.order])  # Look at the next self.order words from current position
            suffix = self.wordbase[i + self.order]  # The next word is the suffix

            # Suffixes are added multiple times on purpose: the more often they are in the list the more common they are
            self.word_table.setdefault(prefix, []).append(suffix)

        print(f"Generated suffixes for {len(self.word_table)} prefixes.")

    def generate_suffix_for_prefix(self, prefix: tuple):
        """
        Return a random word that can follow prefix.

        prefix - Tuple of words, only the last self.order words are used

        Unknown prefixes fall back to the first table entry ending with the same word, then to a random entry.
        Raises KeyError if the table is empty.
        """
        if len(prefix) > self.order:  # In this case we look at the last self.order elements of prefix
            prefix = prefix[len(prefix) - self.order:]

        if prefix in self.word_table:
            return random.choice(self.word_table[prefix])

        # In this case we need to choose a possible suffix from the last word in the prefix (if prefix too short for example)
        print(f"Prefix {prefix} not in table")
        if prefix:
            for key in self.word_table:
                if key and key[-1] == prefix[-1]:
                    return random.choice(self.word_table[key])

        if not self.word_table:
            raise KeyError(prefix)

        # Nothing ends with this word (e.g. the very end of the training text): continue somewhere else
        return random.choice(random.choice(list(self.word_table.values())))

    def _finish_sentence(self, output):
        """
        Append suffixes to the list output until its last word contains a . or MAX_WORDS words were added.
        """
        for _ in range(self.MAX_WORDS):
            if output and "." in output[-1]:
                break
            output.append(self.generate_suffix_for_prefix(tuple(output[-self.order:])))

        return output

    def generate_sentence(self):
        """
        Return a list of words, starting with a random prefix of the table.

        Raises IndexError if the table is empty.
        """
        fword = random.choice(list(self.word_table.keys()))

        return self._finish_sentence([*fword])

    def complete_sentence(self, prefix):
        """
        Return a list of words that starts with the words of prefix.

        prefix - Iterable of words the sentence starts with
        """
        return self._finish_sentence([*prefix])
|
||||
Reference in New Issue
Block a user