8 Commits

Author SHA1 Message Date
c2847de7dd Add instantbuttons command + make responses ephemeral
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
/instantbuttons displays a soundboard via a button ui
2023-12-09 19:51:28 +01:00
08230eb3de Enforce heidi_spam channel for commands
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
2023-12-09 18:44:16 +01:00
f2ddb4ab66 Only play entrance sound when other is present + reformat
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
2023-12-09 18:04:24 +01:00
876232f674 Ignore user config file 2023-12-09 18:03:42 +01:00
d7c3a7c740 Allow sounds with different file extensions
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 14s
Before only .mkv files could be played, as the extension was hardcoded
2023-12-09 17:55:21 +01:00
bdcd5208a7 Untrack Heidi_User.conf 2023-12-09 17:54:56 +01:00
79fcf0142a Some more options for randomly selected answers 2023-12-09 17:48:27 +01:00
0f6cc12182 Delete orphaned code
All checks were successful
Build Heidi Docker image / build-docker (push) Successful in 27s
2023-12-09 17:36:53 +01:00
13 changed files with 133 additions and 558 deletions

1
.gitignore vendored
View File

@ -12,3 +12,4 @@ Pipfile.lock
/disabled_voicelines/ /disabled_voicelines/
*.svg *.svg
.vscode .vscode
Heidi_User.conf

View File

@ -1,38 +0,0 @@
workflow: # for entire pipeline
rules:
- if: '$CI_COMMIT_REF_NAME == "master"' # only run on master...
changes: # ...and when these files have changed
- "*.py"
- "Dockerfile"
docker-build:
stage: build
image: docker:20 # provides the docker toolset (but without an active daemon)
services: # configure images that run during jobs linked to the image (above)
- docker:dind # dind builds on the docker image and starts the Docker daemon (the docker image itself doesn't do that), which is needed to call docker build etc.
before_script:
- docker login -u $CI_REGISTRY_USER -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
script:
- docker pull $CI_REGISTRY_IMAGE:latest || true # latest image for cache (not failing if image is not found)
- >
docker build
--pull
--cache-from $CI_REGISTRY_IMAGE:latest
--label "org.opencontainers.image.title=$CI_PROJECT_TITLE"
--label "org.opencontainers.image.url=$CI_PROJECT_URL"
--label "org.opencontainers.image.created=$CI_JOB_STARTED_AT"
--label "org.opencontainers.image.revision=$CI_COMMIT_SHA"
--label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME"
--tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA
.
- docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA $CI_REGISTRY_IMAGE:latest
- docker push $CI_REGISTRY_IMAGE:latest
docker-deploy:
stage: deploy
image: alpine:3.15
needs: ["docker-build"]
script:
- chmod og= $ID_RSA
- apk update && apk add openssh-client
- ssh -i $ID_RSA -o StrictHostKeyChecking=no $SERVER_USER@$SERVER_IP "/home/christoph/$CI_PROJECT_TITLE/launch.sh"

View File

@ -1,2 +0,0 @@
[ENTRANCE.SOUND]

72
bot.py
View File

@ -120,8 +120,7 @@ def user_entrance_sound_autocomplete(
""" """
boards: List[str] = os.listdir(SOUNDDIR) boards: List[str] = os.listdir(SOUNDDIR)
all_sounds: Dict[str, List[str]] = { all_sounds: Dict[str, List[str]] = {
board: list(map(lambda x: x.split(".")[0], os.listdir(f"{SOUNDDIR}/{board}/"))) board: os.listdir(f"{SOUNDDIR}/{board}/") for board in boards
for board in boards
} # These are all sounds, organized per board } # These are all sounds, organized per board
# @todo Initially only suggest boards, because there are too many sounds to show them all # @todo Initially only suggest boards, because there are too many sounds to show them all
@ -133,7 +132,7 @@ def user_entrance_sound_autocomplete(
for sound in board_sounds: # Iterate over board specific sounds for sound in board_sounds: # Iterate over board specific sounds
soundpath = f"{board}/{sound}" soundpath = f"{board}/{sound}"
if soundpath.lower().startswith(current.lower()): if soundpath.lower().startswith(current.lower()):
completions += [Choice(name=soundpath, value=soundpath)] completions += [Choice(name=soundpath.split(".")[0], value=soundpath)]
return completions return completions
@ -150,6 +149,7 @@ def user_entrance_sound_autocomplete(
config_value="Der Wert, auf welche die Option gesetzt werden soll." config_value="Der Wert, auf welche die Option gesetzt werden soll."
) )
@app_commands.autocomplete(config_value=user_config_value_autocomplete) @app_commands.autocomplete(config_value=user_config_value_autocomplete)
@enforce_channel(HEIDI_SPAM_ID)
async def user_config( async def user_config(
interaction: Interaction, config_key: str, config_value: str interaction: Interaction, config_key: str, config_value: str
) -> None: ) -> None:
@ -159,7 +159,7 @@ async def user_config(
# Only Members can set settings # Only Members can set settings
if not isinstance(interaction.user, Member): if not isinstance(interaction.user, Member):
print("User not a member") print("User not a member")
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!") await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!", ephemeral=True)
return return
member: Member = interaction.user member: Member = interaction.user
@ -168,7 +168,8 @@ async def user_config(
client.write_user_config() client.write_user_config()
await interaction.response.send_message( await interaction.response.send_message(
f"Ok, ich schreibe {member.name}={config_value} in mein fettes Gehirn!" f"Ok, ich schreibe {member.name}={config_value} in mein fettes Gehirn!",
ephemeral=True
) )
@ -176,6 +177,7 @@ async def user_config(
@client.tree.command(name="heidi", description="Heidi!") @client.tree.command(name="heidi", description="Heidi!")
@enforce_channel(HEIDI_SPAM_ID)
async def heidi_exclaim(interaction: Interaction) -> None: async def heidi_exclaim(interaction: Interaction) -> None:
""" """
Print a random Heidi quote. Print a random Heidi quote.
@ -188,6 +190,10 @@ async def heidi_exclaim(interaction: Interaction) -> None:
"Warum denn so schüchtern?", "Warum denn so schüchtern?",
"Im TV ist das legal!", "Im TV ist das legal!",
"Das Stroh ist nur fürs Shooting!", "Das Stroh ist nur fürs Shooting!",
"Jetzt sei doch mal sexy!",
"Stell dich nicht so an!",
"Models müssen da halt durch!",
"Heul doch nicht!",
] ]
await interaction.response.send_message(random.choice(messages)) await interaction.response.send_message(random.choice(messages))
@ -195,6 +201,7 @@ async def heidi_exclaim(interaction: Interaction) -> None:
@client.tree.command(name="miesmuschel", description="Was denkt Heidi?") @client.tree.command(name="miesmuschel", description="Was denkt Heidi?")
@app_commands.rename(question="frage") @app_commands.rename(question="frage")
@app_commands.describe(question="Heidi wird es beantworten!") @app_commands.describe(question="Heidi wird es beantworten!")
@enforce_channel(HEIDI_SPAM_ID)
async def magic_shell(interaction: Interaction, question: str) -> None: async def magic_shell(interaction: Interaction, question: str) -> None:
""" """
Answer a yes/no question. Answer a yes/no question.
@ -208,14 +215,13 @@ async def magic_shell(interaction: Interaction, question: str) -> None:
"Klaro Karo", "Klaro Karo",
"Offensichtlich Sherlock", "Offensichtlich Sherlock",
"Tom sagt Ja", "Tom sagt Ja",
"Nein!", "Nein!",
"Nö.", "Nö.",
"Nä.", "Nä.",
"Niemals!", "Niemals!",
"Nur über meine Leiche du Hurensohn!", "Nur über meine Leiche du Hurensohn!",
"In deinen Träumen.", "In deinen Träumen.",
"Tom sagt Nein" "Tom sagt Nein",
] ]
question = question.strip() question = question.strip()
question_mark = "" if question[-1] == "?" else "?" question_mark = "" if question[-1] == "?" else "?"
@ -230,6 +236,7 @@ async def magic_shell(interaction: Interaction, question: str) -> None:
@app_commands.describe(option_a="Ist es vielleicht dies?") @app_commands.describe(option_a="Ist es vielleicht dies?")
@app_commands.rename(option_b="oder") @app_commands.rename(option_b="oder")
@app_commands.describe(option_b="Oder doch eher das?") @app_commands.describe(option_b="Oder doch eher das?")
@enforce_channel(HEIDI_SPAM_ID)
async def choose(interaction: Interaction, option_a: str, option_b: str) -> None: async def choose(interaction: Interaction, option_a: str, option_b: str) -> None:
""" """
Select an answer from two options. Select an answer from two options.
@ -265,12 +272,10 @@ async def sound_autocomplete(
Suggest a sound from an already selected board. Suggest a sound from an already selected board.
""" """
board: str = interaction.namespace.board board: str = interaction.namespace.board
sounds: List[str] = list( sounds: List[str] = os.listdir(f"{SOUNDDIR}/{board}/")
map(lambda x: x.split(".")[0], os.listdir(f"{SOUNDDIR}/{board}/"))
)
return [ return [
Choice(name=sound, value=sound) Choice(name=sound.split(".")[0], value=sound)
for sound in sounds for sound in sounds
if sound.lower().startswith(current.lower()) if sound.lower().startswith(current.lower())
] ]
@ -282,6 +287,7 @@ async def sound_autocomplete(
@app_commands.describe(sound="Was soll Heidi sagen?") @app_commands.describe(sound="Was soll Heidi sagen?")
@app_commands.autocomplete(board=board_autocomplete) @app_commands.autocomplete(board=board_autocomplete)
@app_commands.autocomplete(sound=sound_autocomplete) @app_commands.autocomplete(sound=sound_autocomplete)
@enforce_channel(HEIDI_SPAM_ID)
async def say_voiceline(interaction: Interaction, board: str, sound: str) -> None: async def say_voiceline(interaction: Interaction, board: str, sound: str) -> None:
""" """
Play a voiceline in the calling member's current voice channel. Play a voiceline in the calling member's current voice channel.
@ -289,7 +295,7 @@ async def say_voiceline(interaction: Interaction, board: str, sound: str) -> Non
# Only Members can access voice channels # Only Members can access voice channels
if not isinstance(interaction.user, Member): if not isinstance(interaction.user, Member):
print("User not a member") print("User not a member")
await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!") await interaction.response.send_message("Heidi sagt: Komm in die Gruppe!", ephemeral=True)
return return
member: Member = interaction.user member: Member = interaction.user
@ -297,6 +303,43 @@ async def say_voiceline(interaction: Interaction, board: str, sound: str) -> Non
await play_voice_line_for_member(interaction, member, board, sound) await play_voice_line_for_member(interaction, member, board, sound)
class InstantButton(discord.ui.Button):
    """
    Red button that stands for a single sound of a single board.
    """

    def __init__(self, label: str, board: str, sound: str):
        """
        Create the button and remember which sound it triggers.

        :param label: Text shown on the button.
        :param board: Board the sound belongs to.
        :param sound: Sound that is handed to play_voice_line_for_member on press.
        """
        super().__init__(style=discord.ButtonStyle.red, label=label)
        self.board = board
        self.sound = sound

    async def callback(self, interaction: Interaction):
        """
        React to a button press by playing the stored sound for the pressing member.
        """
        presser = interaction.user

        # Happy path first: only Members are handed on to the voice line helper
        if isinstance(presser, Member):
            await play_voice_line_for_member(interaction, presser, self.board, self.sound)
            return

        await interaction.response.send_message(
            "Heidi mag keine discord.User, nur discord.Member!", ephemeral=True
        )
class InstantButtons(discord.ui.View):
    """
    View that shows one InstantButton per file of a board directory.
    """

    def __init__(self, board: str, timeout=None):
        """
        Fill the view with buttons for the sounds found in f"{SOUNDDIR}/{board}".

        :param board: Name of the board directory below SOUNDDIR.
        :param timeout: Passed on unchanged to discord.ui.View.
        """
        super().__init__(timeout=timeout)

        # A view can hold at most 25 components, add_item() raises once that limit is exceeded.
        # Boards with more sounds therefore only get buttons for the first 25 (sorted by name).
        max_buttons = 25
        # Discord limits button labels to 80 characters
        max_label_length = 80

        # os.listdir() returns entries in arbitrary order, sort for a stable button layout
        sounds = sorted(os.listdir(f"{SOUNDDIR}/{board}"))

        for sound in sounds[:max_buttons]:
            # Show the name without extension, fall back to the full name if that leaves nothing (".foo")
            label = (sound.split(".")[0] or sound)[:max_label_length]
            self.add_item(InstantButton(label, board, sound))
@client.tree.command(
    name="instantbuttons", description="Heidi malt Knöpfe für Sounds in den Chat."
)
@app_commands.describe(board="Welches Soundboard soll knöpfe bekommen?")
@app_commands.autocomplete(board=board_autocomplete)
@enforce_channel(HEIDI_SPAM_ID)
async def soundboard_buttons(interaction: Interaction, board: str) -> None:
    """
    Answer with a message that carries one button per sound of the selected board.
    """
    # Autocomplete only suggests values, the user can still submit arbitrary text.
    # Only accept names of directories that sit directly in SOUNDDIR, this also rejects
    # path tricks like "../" and avoids an unhandled exception for unknown boards.
    if board not in os.listdir(SOUNDDIR) or not os.path.isdir(f"{SOUNDDIR}/{board}"):
        await interaction.response.send_message(
            f'Heidi sagt: "{board}" kanninich finden bruder', ephemeral=True
        )
        return

    await interaction.response.send_message(
        f"Soundboard: {board.capitalize()}", view=InstantButtons(board)
    )
# Contextmenu ------------------------------------------------------------------------------------ # Contextmenu ------------------------------------------------------------------------------------
@ -313,7 +356,7 @@ async def insult(
if not member.dm_channel: if not member.dm_channel:
print("Error creating DMChannel!") print("Error creating DMChannel!")
await interaction.response.send_message("Heidi sagt: Gib mal DM Nummer süße*r!") await interaction.response.send_message("Heidi sagt: Gib mal DM Nummer süße*r!", ephemeral=True)
return return
insults = [ insults = [
@ -334,7 +377,8 @@ async def insult(
await member.dm_channel.send(random.choice(insults)) await member.dm_channel.send(random.choice(insults))
await interaction.response.send_message( await interaction.response.send_message(
"Anzeige ist raus!" "Anzeige ist raus!",
ephemeral=True
) # with ephemeral = True only the caller can see the answer ) # with ephemeral = True only the caller can see the answer

View File

@ -28,7 +28,7 @@ class HeidiClient(discord.Client):
# lambda m: m.author.nick.lower() in self.models.get_in_names(): self.autoreact_to_girls, # lambda m: m.author.nick.lower() in self.models.get_in_names(): self.autoreact_to_girls,
lambda m: "jeremy" in m.author.nick.lower(): self._autoreact_to_jeremy, lambda m: "jeremy" in m.author.nick.lower(): self._autoreact_to_jeremy,
lambda m: "kardashian" in m.author.nick.lower() lambda m: "kardashian" in m.author.nick.lower()
or "jenner" in m.author.nick.lower(): self._autoreact_to_kardashian, or "jenner" in m.author.nick.lower(): self._autoreact_to_kardashian,
} }
# automatic actions on voice state changes # automatic actions on voice state changes
@ -109,15 +109,26 @@ class HeidiClient(discord.Client):
await message.add_reaction("💄") await message.add_reaction("💄")
async def _play_entrance_sound( async def _play_entrance_sound(
self, self,
member: Member, member: Member,
before: VoiceState, before: VoiceState,
after: VoiceState, after: VoiceState,
) -> None: ) -> None:
""" """
Play a sound when a member joins a voice channel. Play a sound when a member joins a voice channel (and another member is present).
This function is set in on_voice_state_triggers and triggered by the on_voice_state_update event. This function is set in on_voice_state_triggers and triggered by the on_voice_state_update event.
""" """
# Don't play anything when no other users are present
if (
member is not None
and member.voice is not None
and member.voice.channel is not None
and len(member.voice.channel.members) <= 1
):
print("Not playing entrance sound, as no other members are present")
return
soundpath: Union[str, None] = self.user_config["ENTRANCE.SOUND"].get( soundpath: Union[str, None] = self.user_config["ENTRANCE.SOUND"].get(
member.name, None member.name, None
) )

View File

@ -24,3 +24,6 @@ SOUNDDIR: str = "/sounds" if DOCKER else "./heidi-sounds"
# IDs of the servers Heidi is used on # IDs of the servers Heidi is used on
LINUS_GUILD = discord.Object(id=431154792308408340) LINUS_GUILD = discord.Object(id=431154792308408340)
TEST_GUILD = discord.Object(id=821511861178204161) TEST_GUILD = discord.Object(id=821511861178204161)
# Channel IDs
HEIDI_SPAM_ID = 822223476101742682

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import functools
from typing import Union from typing import Union
import discord import discord
@ -8,31 +9,66 @@ from heidi_constants import *
print("Debug: Importing heidi_helpers.py") print("Debug: Importing heidi_helpers.py")
# Checks -----------------------------------------------------------------------------------------
# 1. @enforce_channel(ID) is evaluated first and returns decorate, which keeps channel_id in its closure
# 2. The decorated function is then passed to decorate(function), which returns the wrapped function
def enforce_channel(channel_id):
    """
    Only run a function if called from the correct channel.

    :param channel_id: ID the interaction's channel_id has to be equal to.
    :return: Decorator for coroutine functions that receive the interaction as their first argument.
    """

    def decorate(function):
        @functools.wraps(function)  # Copy name, docstring etc. of the decorated function onto the wrapper
        async def wrapped(*args, **kwargs):
            """
            Sends an interaction response if the interaction is not triggered from the heidi_spam channel.
            """
            # The interaction is normally the first positional argument, accept the keyword form as well
            interaction: Interaction = args[0] if args else kwargs["interaction"]

            # Do not call the decorated function if the channel_id doesn't match
            if interaction.channel_id != channel_id:
                await interaction.response.send_message(
                    "Heidi sagt: Geh in heidi_spam du dulli", ephemeral=True
                )
                return None

            # Hand the result of the decorated function back instead of dropping it
            return await function(*args, **kwargs)

        return wrapped

    return decorate
# Sounds -----------------------------------------------------------------------------------------
# @todo Normalize volume when playing # @todo Normalize volume when playing
async def play_voice_line( async def play_voice_line(
interaction: Union[Interaction, None], interaction: Union[Interaction, None],
voice_channel: VoiceChannel, voice_channel: VoiceChannel,
board: str, board: str,
sound: str, sound: str,
) -> None: ) -> None:
""" """
Play a voice line in the specified channel. Play a voice line in the specified channel.
""" """
try: try:
open(f"{SOUNDDIR}/{board}/{sound}.mkv") open(f"{SOUNDDIR}/{board}/{sound}")
except IOError: except IOError:
print("Error: Invalid soundfile!") print(f"Error: Invalid soundfile {SOUNDDIR}/{board}/{sound}!")
if interaction is not None: if interaction is not None:
await interaction.response.send_message( await interaction.response.send_message(
f'Heidi sagt: "{board}/{sound}" kanninich finden bruder' f'Heidi sagt: "{board}/{sound}" kanninich finden bruder',
ephemeral=True
) )
return return
if interaction is not None: if interaction is not None:
await interaction.response.send_message(f'Heidi sagt: "{board}/{sound}"') await interaction.response.send_message(f'Heidi sagt: "{board}/{sound}"', ephemeral=True)
audio_source = discord.FFmpegPCMAudio( audio_source = discord.FFmpegPCMAudio(
f"{SOUNDDIR}/{board}/{sound}.mkv" f"{SOUNDDIR}/{board}/{sound}"
) # only works from docker ) # only works from docker
voice_client = await voice_channel.connect() voice_client = await voice_channel.connect()
voice_client.play(audio_source) voice_client.play(audio_source)
@ -44,24 +80,24 @@ async def play_voice_line(
async def play_voice_line_for_member( async def play_voice_line_for_member(
interaction: Union[Interaction, None], interaction: Union[Interaction, None],
member: Member, member: Member,
board: str, board: str,
sound: str, sound: str,
) -> None: ) -> None:
""" """
Play a voice line in the member's current channel. Play a voice line in the member's current channel.
""" """
# Member needs to be in voice channel to hear audio (Heidi needs to know the channel to join) # Member needs to be in voice channel to hear audio (Heidi needs to know the channel to join)
if ( if (
member is None member is None
or member.voice is None or member.voice is None
or member.voice.channel is None or member.voice.channel is None
or not isinstance(member.voice.channel, VoiceChannel) or not isinstance(member.voice.channel, VoiceChannel)
): ):
print("User not in (valid) voice channel!") print("User not in (valid) voice channel!")
if interaction is not None: if interaction is not None:
await interaction.response.send_message("Heidi sagt: Komm in den Channel!") await interaction.response.send_message("Heidi sagt: Komm in den Channel!", ephemeral=True)
return return
voice_channel: VoiceChannel = member.voice.channel voice_channel: VoiceChannel = member.voice.channel

View File

@ -1,9 +0,0 @@
#!/bin/sh
# Update the checkout, pull the newest image and restart the heidibot container.

# Abort if the project directory is missing, otherwise git/docker would run in the wrong place
cd /home/christoph/HeidiBot || exit 1
git pull
docker pull registry.gitlab.com/churl/heidibot
# Remove the old container (running or not) so the name is free again
docker container rm -f heidibot
docker run -d --env-file /home/christoph/HeidiBot/.env --mount src=/home/christoph/HeidiBot/voicelines,target=/sounds,type=bind --name heidibot registry.gitlab.com/churl/heidibot
# Clean up images that are no longer referenced after the pull
docker image prune -f

View File

@ -1,33 +0,0 @@
#!/usr/bin/env python3
import requests
import re
from bs4 import BeautifulSoup
class Models:
    """
    Scrape the candidate list of the Germany's Next Topmodel page on prosieben.de.

    The page is downloaded once in the constructor. Candidates are the <a> elements with the
    class "candidate-in" or "candidate-out", stored in dicts keyed by their lowercased title.
    """

    def __init__(self):
        url_girls = "https://www.prosieben.de/tv/germanys-next-topmodel/models"

        # Without a timeout requests waits forever if the server never answers
        html_girls = requests.get(url_girls, timeout=10)
        soup_girls = BeautifulSoup(html_girls.text, "html.parser")

        # find_all() is the current name of the legacy findAll() alias
        girls_in = soup_girls.find_all("a", class_="candidate-in")
        girls_out = soup_girls.find_all("a", class_="candidate-out")

        self.girls_in = {girl.get("title").lower(): girl for girl in girls_in}
        self.girls_out = {girl.get("title").lower(): girl for girl in girls_out}

        # Lookup table over both groups
        self.girls = {**self.girls_in, **self.girls_out}

    def get_in_names(self):
        """
        Return the (lowercased) titles of all "candidate-in" entries.
        """
        return self.girls_in.keys()

    def get_out_names(self):
        """
        Return the (lowercased) titles of all "candidate-out" entries.
        """
        return self.girls_out.keys()

    def get_image(self, name):
        """
        Return the image url of a candidate with the resolution suffix replaced by 562x996.

        :param name: Title of the candidate, case is ignored.
        :raises KeyError: If no candidate with that title was scraped.
        """
        style = self.girls[name.lower()].find("figure", class_="teaser-img").get("style")
        url = re.search(r"url\(.*\);", style).group()

        # Strip the leading "url(" and the last 9 characters
        # (presumably "WWWxHHH);" - verify against the current page markup)
        return url[4:-9] + "562x996"  # increase resolution

View File

@ -3,12 +3,3 @@ rich
discord.py # maintained again discord.py # maintained again
pynacl # voice support pynacl # voice support
python-dotenv # discord guild secrets python-dotenv # discord guild secrets
# Webscraping
# requests
# beautifulsoup4
# Textgeneration
# torch
# numpy
# nltk

View File

@ -1,44 +0,0 @@
#!/usr/bin/env python3
from rich.traceback import install
install()
from abc import ABC, abstractmethod
# In Python it is generally not needed to use abstract classes, but I wanted to do it safely
class textgen(ABC):
    """
    Interface every text generator has to implement.

    All methods are abstract, so the class can't be instantiated directly and subclasses
    have to override each of them.
    """

    @abstractmethod
    def init(self, filename):
        """
        Prepare the generator for a training text.

        filename - The file (same directory as textgen.py) that contains the training text
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def load(self):
        """
        Restore an already trained model from a precomputed file instead of training it
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def train(self):
        """
        Build the model from the training text
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def generate_sentence(self):
        """
        Produce words/characters until a . is generated
        """
        raise NotImplementedError("Can't use abstract class")

    @abstractmethod
    def complete_sentence(self, prefix):
        """
        Continue a sentence from the given beginning

        prefix - The beginning of the sentence
        """
        raise NotImplementedError("Can't use abstract class")

View File

@ -1,303 +0,0 @@
#!/usr/bin/env python3
import re, random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from textgen import textgen
from torch import nn, optim
from rich.traceback import install
install()
# Model =======================================================================================
# https://towardsdatascience.com/text-generation-with-bi-lstm-in-pytorch-5fda6e7cc22c
# Embedding -> Bi-LSTM -> LSTM -> Linear
class Model(nn.ModuleList):
    """
    Character level network: Embedding -> Bi-LSTM -> LSTM -> Linear.

    args is a dict that has to contain "batch_size", "hidden_dim", "vocab_size" and "window".
    The network maps a batch of index windows of shape (batch, window) to one score per
    vocabulary entry, shape (batch, vocab_size).
    """

    def __init__(self, args, device):
        super(Model, self).__init__()

        self.device = device
        self.batch_size = args["batch_size"]
        self.hidden_dim = args["hidden_dim"]
        self.input_size = args["vocab_size"]
        self.num_classes = args["vocab_size"]
        self.sequence_len = args["window"]

        # Dropout (defined but not applied in forward())
        # Don't need to set device for the layers as we transfer the whole model later
        self.dropout = nn.Dropout(0.25)

        # Embedding layer
        self.embedding = nn.Embedding(self.input_size, self.hidden_dim, padding_idx=0)

        # Bi-LSTM: one cell reads the window forward, one reads it backward
        self.lstm_cell_forward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        self.lstm_cell_backward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)

        # LSTM layer, consumes the concatenated forward/backward states
        self.lstm_cell = nn.LSTMCell(self.hidden_dim * 2, self.hidden_dim * 2)

        # Linear layer
        self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)

    def _init_state(self, batch_size, width):
        """
        Create a randomly (kaiming normal) initialized hidden/cell state of shape (batch_size, width).
        """
        # Need to specify device here as the states are not part of the model directly
        state = torch.zeros(batch_size, width, device=self.device)
        torch.nn.init.kaiming_normal_(state)
        return state

    def forward(self, x):
        """
        Compute the scores for the character following each window.

        :param x: Long tensor of character indices with shape (batch, window).
        :return: Tensor of shape (batch, vocab_size).
        :raises RuntimeError: If x does not have the shape (batch, window).
        """
        if x.dim() != 2 or x.size(1) != self.sequence_len:
            raise RuntimeError(
                f"Expected input of shape (batch, {self.sequence_len}), got {tuple(x.shape)}"
            )

        batch_size = x.size(0)

        # Bi-LSTM states: [batch_size x hidden_size]
        hs_forward = self._init_state(batch_size, self.hidden_dim)
        cs_forward = self._init_state(batch_size, self.hidden_dim)
        hs_backward = self._init_state(batch_size, self.hidden_dim)
        cs_backward = self._init_state(batch_size, self.hidden_dim)

        # LSTM states: [batch_size x (hidden_size * 2)]
        hs_lstm = self._init_state(batch_size, self.hidden_dim * 2)
        cs_lstm = self._init_state(batch_size, self.hidden_dim * 2)

        # From idx to embedding: (batch, window, hidden)
        out = self.embedding(x)

        # The cells are fed one time step at a time, so the time axis has to come first.
        # This must be a transpose: view() would only reinterpret the memory and mix
        # the samples of a batch into each other's time steps.
        out = out.transpose(0, 1)

        forward = []
        backward = []

        # Unfolding Bi-LSTM
        # Forward
        for i in range(self.sequence_len):
            hs_forward, cs_forward = self.lstm_cell_forward(out[i], (hs_forward, cs_forward))
            forward.append(hs_forward)

        # Backward
        for i in reversed(range(self.sequence_len)):
            hs_backward, cs_backward = self.lstm_cell_backward(out[i], (hs_backward, cs_backward))
            backward.append(hs_backward)

        # LSTM
        # NOTE(review): backward holds the states in reversed time order, so forward[t] is paired
        # with the backward state of step window-1-t. Kept as is to stay compatible with existing weights.
        for fwd, bwd in zip(forward, backward):
            input_tensor = torch.cat((fwd, bwd), 1)
            hs_lstm, cs_lstm = self.lstm_cell(input_tensor, (hs_lstm, cs_lstm))

        # Last hidden state is passed through a linear layer
        out = self.linear(hs_lstm)

        return out
# =============================================================================================
class LSTMTextGenerator(textgen):
    """
    Character based text generator that predicts the next character from a sliding window.

    Usage: create with the window size, call init(filename), then either train() or load(),
    afterwards generate_sentence() / complete_sentence(prefix).
    """

    def __init__(self, windowsize):
        # We slide a window over the character sequence and look at the next letter,
        # similar to the Markov chain order
        self.windowsize = windowsize

    def init(self, filename):
        """
        Read ./textfiles/{filename}.txt, set up the model and build the training sequences.

        :param filename: Name of the training text without directory and extension.
        """
        self.filename = filename

        # Use this to generate one hot vector and filter characters
        self.letters = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
                        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "ä", "ö", "ü", ".", " "]

        # The alphabet contains umlauts, so don't depend on the locale's default encoding
        with open(f"./textfiles/{filename}.txt", "r", encoding="utf-8") as file:
            text = " ".join(line.lower() for line in file)  # single lowercase string

        allowed = set(self.letters)
        self.charbase = [char for char in text if char in allowed]  # list of characters

        # Select device
        if torch.cuda.is_available():
            dev = "cuda:0"
            print("Selected GPU for LSTM")
        else:
            dev = "cpu"
            print("Selected CPU for LSTM")

        self.device = torch.device(dev)

        # Init model
        self.args = {
            "window": self.windowsize,
            "hidden_dim": 128,
            "vocab_size": len(self.letters),
            "batch_size": 128,
            "learning_rate": 0.0005,
            "num_epochs": 100
        }
        self.model = Model(self.args, self.device)
        self.model.to(self.device)  # All model layers need to use the correct tensors (cpu/gpu)

        # Needed for both training and generation
        self.__generate_char_sequences()

    # Helpers

    def __char_to_idx(self, char):
        """
        Return the index of char in self.letters, raises ValueError for unknown characters.
        """
        return self.letters.index(char)

    def __idx_to_char(self, idx):
        """
        Return the character that belongs to an index of self.letters.
        """
        return self.letters[idx]

    def __generate_char_sequences(self):
        """
        Slide a window over self.charbase and store every window with the character following it.

        self.prefixes becomes an array of shape (count, windowsize), self.suffixes of shape (count,).
        """
        # Example
        # [[21, 20, 15],
        #  [12, 12, 14]]
        prefixes = []

        # Example
        # [1, 4]
        suffixes = []

        print("Generating LSTM char sequences...")
        for i in range(len(self.charbase) - self.windowsize):
            window = self.charbase[i:i + self.windowsize]
            prefixes.append([self.__char_to_idx(char) for char in window])
            # The suffix is the character directly behind the window (index i + windowsize)
            suffixes.append(self.__char_to_idx(self.charbase[i + self.windowsize]))

        # Enter numpy territory NOW
        self.prefixes = np.array(prefixes)
        self.suffixes = np.array(suffixes)

        print(f"Prefixes shape: {self.prefixes.shape}")
        print(f"Suffixes shape: {self.suffixes.shape}")
        print("Completed.")

    # Interface

    # @todo Also save/load generated prefixes
    def load(self):
        """
        Load the model weights from weights/{filename}_lstm_model.pt.
        """
        print(f"Loading \"{self.filename}\" LSTM model with {len(self.charbase)} characters from file.")

        # map_location allows loading weights that were saved on another device (trained on GPU, used on CPU)
        self.model.load_state_dict(
            torch.load(f"weights/{self.filename}_lstm_model.pt", map_location=self.device)
        )

    def train(self):
        """
        Train the model on the generated sequences, save the weights and plot the loss.
        """
        print(f"Training \"{self.filename}\" LSTM model with {len(self.charbase)} characters.")

        # Optimizer initialization, RMSprop for RNN
        optimizer = optim.RMSprop(self.model.parameters(), lr=self.args["learning_rate"])

        batch_size = self.args["batch_size"]

        # Set model in training mode
        self.model.train()

        losses = []

        # Training phase
        for epoch in range(self.args["num_epochs"]):
            epoch_loss = None

            # Mini batches, slicing also yields the shorter last batch without raising
            for start in range(0, len(self.prefixes), batch_size):
                x_batch = self.prefixes[start:start + batch_size]
                y_batch = self.suffixes[start:start + batch_size]

                # Convert numpy array into torch tensors
                x = torch.from_numpy(x_batch).type(torch.long).to(self.device)
                y = torch.from_numpy(y_batch).type(torch.long).to(self.device)

                # Feed the model
                y_pred = self.model(x)

                # Loss calculation, y is already 1-dimensional (one class index per sample)
                loss = F.cross_entropy(y_pred, y)
                epoch_loss = loss.item()
                losses.append(epoch_loss)

                # Clean gradients
                optimizer.zero_grad()

                # Calculate gradients
                loss.backward()

                # Update parameters
                optimizer.step()

            # There is no loss to report if the training text was too short for a single window
            if epoch_loss is not None:
                print("Epoch: %d , loss: %.5f " % (epoch, epoch_loss))

        torch.save(self.model.state_dict(), f"weights/{self.filename}_lstm_model.pt")
        print(f"Saved \"{self.filename}\" LSTM model to file")

        plt.plot(np.arange(0, len(losses)), losses)
        plt.title(self.filename)
        plt.show()

    def generate_sentence(self):
        """
        Complete a sentence that starts with a random window of the training text.
        """
        # randint excludes the upper bound, so every sequence can be selected
        start = np.random.randint(0, len(self.prefixes))

        # Convert back to string to match complete_sentence
        pattern = "".join([self.__idx_to_char(char) for char in self.prefixes[start]])  # random sequence from the training text

        return self.complete_sentence(pattern)

    def complete_sentence(self, prefix):
        """
        Append predicted characters to prefix until a . is generated (at most 500 characters).

        :param prefix: String of characters from self.letters with at least windowsize characters.
                       Only the last windowsize characters are fed to the model.
        :return: The prefix followed by the generated characters.
        :raises ValueError: If prefix contains a character that is not in self.letters.
        """
        print("Prefix:", prefix)

        # Convert to indexes np.array
        pattern = np.array([self.__char_to_idx(char) for char in prefix])

        # Set the model in evaluation mode
        self.model.eval()

        # Define the softmax function
        softmax = nn.Softmax(dim=1).to(self.device)

        # In full_prediction we will save the complete prediction
        full_prediction = pattern.copy()

        # The model only accepts exactly one window
        window = pattern[-self.windowsize:]

        print("Generating sentence...")

        # Predict the next characters one by one, append chars to the starting pattern until . is reached, max 500 iterations
        with torch.no_grad():  # No gradients needed for generation
            for _ in range(500):
                # the numpy window is transformed into a tensor and reshaped to a batch of one
                batch = torch.from_numpy(window).type(torch.long).to(self.device)
                batch = batch.view(1, -1)

                # make a prediction given the window and turn it into probabilities
                prediction = softmax(self.model(batch))

                # take the idx with the highest probability
                arg_max = int(np.argmax(prediction.squeeze().cpu().numpy()))

                # slide the window 1 character to the right and append the predicted character
                window = np.append(window[1:], arg_max)

                # the full prediction is saved
                full_prediction = np.append(full_prediction, arg_max)

                # Stop on . character
                if self.__idx_to_char(arg_max) == ".":
                    break

        full_prediction = "".join([self.__idx_to_char(value) for value in full_prediction])

        print("Generated:", full_prediction)

        return full_prediction

View File

@ -1,82 +0,0 @@
#!/usr/bin/env python3
import re
import random
from textgen import textgen
from rich.traceback import install
install()
# NOTE: This is word based, not character based
# @todo Serialize and save/load model (don't train on the server)
# @todo Maybe extract sentence beginnings and use them as starters?
class MarkovTextGenerator(textgen):
    """
    Word based Markov chain text generator.

    The greater the order (prefix length), the lesser the variation in generation, but the better
    the sentences (generally). If the prefix length is high there are less options to choose from,
    so the sentences are very close to the training text.
    """

    def __init__(self, order):  # Set order here for better interface (only needed for markov model)
        self.order = order

    def init(self, filename):  # Filename is needed for every type of model so it's part of the interface
        """
        Read ./textfiles/{filename}.txt into a list of lowercase words and reset the table.
        """
        # The pattern contains umlauts, so don't depend on the locale's default encoding
        with open(f"./textfiles/{filename}.txt", "r", encoding="utf-8") as file:
            # Remove all characters except a-zäöüß'.,
            self.wordbase = re.sub(r"[^a-zäöüß'.,]+", " ", file.read().lower()).split()

        self.word_table = dict()

    def load(self):
        """
        Only prints a message, nothing is loaded and the table stays empty (see the @todo of this file).
        """
        print(f"Loaded Markov chain of order {self.order} with {len(self.wordbase)} words from file.")

    def train(self):
        """
        Fill self.word_table with every prefix of self.order words and the words that followed it.
        """
        print(f"Training Markov chain of order {self.order} with {len(self.wordbase)} words.")

        # The last valid prefix starts at len - order - 1, its suffix is the very last word
        for i in range(len(self.wordbase) - self.order):
            prefix = tuple(self.wordbase[i:i + self.order])  # The next self.order words from current position
            suffix = self.wordbase[i + self.order]  # The next word is the suffix

            # Suffixes are stored multiple times on purpose: more common suffixes are chosen more often
            self.word_table.setdefault(prefix, []).append(suffix)

        print(f"Generated suffixes for {len(self.word_table)} prefixes.")

    def generate_suffix_for_prefix(self, prefix: tuple):
        """
        Return a random word that can follow prefix.

        :param prefix: Tuple of words, only the last self.order words are used.
        :raises KeyError: If the prefix is unknown and the table is empty (not trained).
        """
        if len(prefix) > self.order:  # In this case we look at the last self.order elements of prefix
            prefix = prefix[-self.order:]

        if prefix in self.word_table:
            return random.choice(self.word_table[prefix])

        # In this case we need to choose a possible suffix from the last word in the prefix
        # (if prefix too short for example)
        print(f"Prefix {prefix} not in table")

        if not self.word_table:
            raise KeyError(prefix)

        if prefix:
            for key, suffixes in self.word_table.items():
                if key[-1] == prefix[-1]:
                    return random.choice(suffixes)

        # Dead end (e.g. the last word of the training text): continue with any known suffix
        return random.choice(random.choice(list(self.word_table.values())))

    def generate_sentence(self):
        """
        Generate words from a random known prefix until a word contains a . and return them as list.
        """
        fword = random.choice(list(self.word_table.keys()))
        output = [*fword]

        while "." not in output[-1]:
            output.append(self.generate_suffix_for_prefix(tuple(output[-self.order:])))

        return output

    def complete_sentence(self, prefix):
        """
        Continue the given beginning until a word contains a . and return all words as list.

        :param prefix: Beginning of the sentence, either a sequence of words or a string
                       (which is split on whitespace). An empty prefix yields a random sentence.
        """
        # Unpacking a string would produce single characters instead of words
        output = prefix.split() if isinstance(prefix, str) else [*prefix]

        if not output:
            return self.generate_sentence()

        while "." not in output[-1]:
            output.append(self.generate_suffix_for_prefix(tuple(output[-self.order:])))

        return output