From abe1d62fc043c292a2041c8b0015bd09b95e017a Mon Sep 17 00:00:00 2001 From: ChUrl Date: Sun, 13 Nov 2022 16:35:18 +0100 Subject: [PATCH] add python ML flake --- flakes/python_pytorch.nix | 70 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 flakes/python_pytorch.nix diff --git a/flakes/python_pytorch.nix b/flakes/python_pytorch.nix new file mode 100644 index 00000000..39c3176b --- /dev/null +++ b/flakes/python_pytorch.nix @@ -0,0 +1,70 @@ +{ + description = "Machine Learning Environment"; + + inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + inputs.flake-utils.url = "github:numtide/flake-utils"; + inputs.devshell.url = "github:numtide/devshell"; + + outputs = { self, nixpkgs, flake-utils, devshell }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + config.allowUnfree = true; + overlays = [ devshell.overlay ]; + }; + + # TODO: Originally it was nixpkgs.fetchurl but that didn't work, pkgs.fetchurl did... + # Determine the difference between nixpkgs and pkgs + + # NOTE: These packages have to be updated manually! + + # Taken from: https://github.com/gbtb/nix-stable-diffusion/blob/master/flake.nix + # Overlay: https://nixos.wiki/wiki/Overlays + # FetchURL: https://ryantm.github.io/nixpkgs/builders/fetchers/ + torch-rocm = pkgs.hiPrio (pkgs.python310Packages.torch-bin.overrideAttrs (old: { + src = pkgs.fetchurl { + name = "torch-1.12.1+rocm5.1.1-cp310-cp310-linux_x86_64.whl"; + url = "https://download.pytorch.org/whl/rocm5.1.1/torch-1.12.1%2Brocm5.1.1-cp310-cp310-linux_x86_64.whl"; + hash = "sha256-kNShDx88BZjRQhWgnsaJAT8hXnStVMU1ugPNMEJcgnA="; + }; + })); + torchvision-rocm = pkgs.hiPrio (pkgs.python310Packages.torchvision-bin.overrideAttrs (old: { + src = pkgs.fetchurl { + name = "torchvision-0.13.1+rocm5.1.1-cp310-cp310-linux_x86_64.whl"; + url = "https://download.pytorch.org/whl/rocm5.1.1/torchvision-0.13.1%2Brocm5.1.1-cp310-cp310-linux_x86_64.whl"; + hash = "sha256-mYk4+XNXU6rjpgWfKUDq+5fH/HNPQ5wkEtAgJUDN/Jg="; + }; + })); + + myPython = pkgs.python310.withPackages (p: with p; [ + # Basic + rich + + # MachineLearning + torch-rocm + torchvision-rocm + numpy + matplotlib + nltk + ]); + in { + devShell = pkgs.devshell.mkShell { + name = "Machine Learning Environment"; + + packages = with pkgs; [ + myPython + nodePackages.pyright # LSP + ]; + + # Use $1 for positional args + commands = [ + # { + # name = ""; + # help = ""; + # command = ""; + # } + ]; + }; + }); +}