diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2dc53ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..71ebc5a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: https://github.com/pycqa/isort + rev: 5.10.1 + hooks: + - id: isort + name: isort (python) + args: + - "--profile=black" + + - repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black + args: + - --line-length=88 + - --include='\.pyi?$' + + - repo: https://github.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: + - "--max-line-length=88" + - "--max-complexity=18" + - "--select=B,C,E,F,W,T4,B9,c90" + - "--ignore=E203,E266,E501,W503,F403,F401,E402" + +default_language_version: + python: python3.10 diff --git a/README.md b/README.md index a5af77d..87c9f5d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Data Mining Tool +## Requirements +- Python >= 3.10 +- dependencies listed in the file [requirements.txt](requirements.txt) + ## Getting started ```commandline @@ -7,7 +11,8 @@ git clone https://github.com/mhawryluk/data-mining-tool data-mining-tool cd data-mining-tool virtualenv .venv source .venv/bin/activate -python3 -m pip install -r requirements.txt - -python3 src/app.py +python -m pip install -r requirements.txt +export MONGO_PASS= +cd src +python app.py ``` \ No newline at end of file diff --git a/qt-designer-uis/algorithm-widget.ui b/qt-designer-uis/algorithm-widget.ui new file mode 100644 index 0000000..718f527 --- /dev/null +++ b/qt-designer-uis/algorithm-widget.ui @@ -0,0 +1,1547 @@ + + + Form + + + + 0 + 0 + 815 + 491 + + + + Form + + + + + 0 + 0 + 111 + 491 + + + + Qt::LeftToRight + + + background-color: rgb(177, 221, 240); +writing-mode: vertical-rl; +-webkit-transform: rotate(-180deg); +-moz-transform: rotate(-180deg); + + + QFrame::Box + + + ALGORITHM + + + false + + + Qt::AlignCenter + + + true + + + false + + + + + + 109 + -1 + 711 + 501 + + + + false + + + background-color: rgb(245, 252, 255); + + + + QFrame::StyledPanel + 
+ + QFrame::Raised + + + + + 30 + 90 + 221 + 111 + + + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 252 + 254 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 252 + 254 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 126 + 127 + 127 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 126 + 127 + 127 + + + + + + + 255 + 255 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 252 + 254 + 255 + + + + + + + 252 + 254 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + Exploration technique + + + + + 30 + 50 + 121 + 41 + + + + color: rgb(0,0,0); + + + + clastering + + + + + associations + + + + + + + + 40 + 270 + 221 + 111 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 
+ 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 126 + 127 + 127 + + + + + + + 255 + 255 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + Algorithm + + + + + 30 + 50 + 121 + 41 + + + + color: rgb(0,0,0); + + + + + + + 340 + 100 + 281 + 251 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 126 + 127 + 127 + + + 
+ + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 126 + 127 + 127 + + + + + + + 255 + 255 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + Options + + + + + 160 + 70 + 81 + 31 + + + + color: rgb(0,0,0) + + + + + + 20 + 60 + 131 + 51 + + + + color: rgb(0,0,0); + + + number of clasters + + + Qt::AlignCenter + + + + + + 20 + 160 + 131 + 51 + + + + color: rgb(0,0,0); + + + initialization method + + + Qt::AlignCenter + + + + + + 170 + 170 + 81 + 41 + + + + color: rgb(0,0,0) + + + Forgy + + + + + + + + diff --git a/qt-designer-uis/import-widget.ui b/qt-designer-uis/import-widget.ui new file mode 100644 index 0000000..4d05294 --- /dev/null +++ b/qt-designer-uis/import-widget.ui @@ -0,0 +1,1673 @@ + + + Form + + + + 0 + 0 + 819 + 491 + + + + Form + + + + + 0 + 0 + 111 + 491 + + + + Qt::LeftToRight + + + background-color: rgb(177, 221, 240); +writing-mode: vertical-rl; +-webkit-transform: rotate(-180deg); +-moz-transform: rotate(-180deg); + + + QFrame::Box + + + IMPORT DATA + + + false + + + Qt::AlignCenter + + + true + + + false + + + + + + 109 + -1 + 711 + 501 + + + + false + + + background-color: rgb(245, 252, 255); + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + true + + + + 30 + 30 + 221 + 161 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + 
+ + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 126 + 127 + 127 + + + + + + + 255 + 255 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + Load data + + + + + 10 + 50 + 113 + 23 + + + + + + + 10 + 100 + 111 + 23 + + + + + + + 10 + 30 + 101 + 16 + + + + Set path to file: + + + + + + 10 + 80 + 181 + 16 + + + + Choose data from database: + + + + + + 150 + 50 + 51 + 23 + + + + LOAD + + + + + + 150 + 100 + 51 + 23 + + + + LOAD + + + + + + 10 + 130 + 201 + 16 + + + + Some error + + + + + + + 30 + 220 + 221 + 171 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 
252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 126 + 127 + 127 + + + + + + + 255 + 255 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + Options + + + + + 20 + 120 + 191 + 31 + + + + This file is too big. +You must save it in database! + + + + + + 10 + 30 + 201 + 23 + + + + Reject this data + + + + + + 10 + 60 + 201 + 23 + + + + Save to database and set data + + + + + false + + + + 10 + 90 + 201 + 23 + + + + Set data + + + false + + + + + + + 360 + 30 + 321 + 361 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 0 + 0 + 0 + + + + + + + 255 + 255 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 253 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 255 + 255 + 255 + + + + + + + 253 + 254 
+ 255 + + + + + + + 126 + 127 + 127 + + + + + + + 168 + 169 + 170 + + + + + + + 126 + 127 + 127 + + + + + + + 255 + 255 + 255 + + + + + + + 126 + 127 + 127 + + + + + + + 245 + 252 + 255 + + + + + + + 245 + 252 + 255 + + + + + + + 0 + 0 + 0 + + + + + + + 252 + 254 + 255 + + + + + + + 255 + 255 + 220 + + + + + + + 0 + 0 + 0 + + + + + + + 0 + 0 + 0 + + + + + + + + Columns + + + + + 10 + 30 + 301 + 101 + + + + + + + something + + + true + + + + + + + city + + + true + + + + + + + id + + + true + + + + + + + height + + + true + + + + + + + age + + + true + + + + + + + birth date + + + true + + + + + + + name + + + true + + + + + + + + + + + diff --git a/requirements.txt b/requirements.txt index 718cf7c..5350c2b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,32 @@ cycler==0.11.0 -fonttools==4.32.0 -kiwisolver==1.4.2 -matplotlib==3.5.1 -numpy==1.22.3 +dnspython==2.2.1 +fonttools==4.38.0 +kiwisolver==1.4.4 +matplotlib==3.6.2 +networkx==2.8.8 +numpy==1.23.4 packaging==21.3 -Pillow==9.1.0 -pyparsing==3.0.8 -PyQt5==5.15.6 +pandas==1.5.1 +Pillow==9.3.0 +plotly==5.11.0 +psutil==5.9.4 +pymongo==4.3.2 +pyparsing==3.0.9 +PyQt5==5.15.7 PyQt5-Qt5==5.15.2 -PyQt5-sip==12.9.1 -PyQt5-stubs==5.15.2.0 +PyQt5-sip==12.11.0 +PyQt5-stubs==5.15.6.0 +PyQtWebEngine==5.15.6 +PyQtWebEngine-Qt5==5.15.2 python-dateutil==2.8.2 +pytz==2022.6 six==1.16.0 +joblib==1.2.0 +pygraphviz~=1.10 +QGraphViz~=0.0.55 +pip==22.3.1 +wheel==0.38.4 +setuptools==65.5.1 +graphviz~=0.20.1 +scipy==1.9.3 +pre-commit diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/algorithms/__init__.py b/src/algorithms/__init__.py index e69de29..6b35e09 100644 --- a/src/algorithms/__init__.py +++ b/src/algorithms/__init__.py @@ -0,0 +1,2 @@ +from .algorithm import Algorithm +from .utils import check_numeric, get_samples, get_threads_count diff --git a/src/algorithms/algorithm.py b/src/algorithms/algorithm.py new file mode 100644 index 0000000..520d2c9 --- 
/dev/null +++ b/src/algorithms/algorithm.py @@ -0,0 +1,32 @@ +from abc import abstractmethod +from typing import List + + +class Algorithm: + """ + Abstract class of algorithm + """ + + metrics_info = {} + + @abstractmethod + def run(self, with_steps: bool): + """ + Run algorithm and return result for class AlgorithmResultsWidget + If with_steps is true, saves steps of algorithm creation + """ + raise NotImplementedError + + @abstractmethod + def get_steps(self) -> List: + """ + Return list of steps for visualization by AlgorithmStepsVisualization + """ + raise NotImplementedError + + @abstractmethod + def update_metrics(self, *args): + """ + Update metrics_info dict + """ + raise NotImplementedError diff --git a/src/algorithms/associations/__init__.py b/src/algorithms/associations/__init__.py index e69de29..0b3117a 100644 --- a/src/algorithms/associations/__init__.py +++ b/src/algorithms/associations/__init__.py @@ -0,0 +1 @@ +from .a_priori import APriori, APrioriPartLabel diff --git a/src/algorithms/associations/a_priori.py b/src/algorithms/associations/a_priori.py new file mode 100644 index 0000000..a1df925 --- /dev/null +++ b/src/algorithms/associations/a_priori.py @@ -0,0 +1,232 @@ +from enum import Enum +from itertools import chain, combinations +from typing import List, Optional, Tuple + +import pandas as pd + +from algorithms import Algorithm +from utils import format_set + + +class APriori(Algorithm): + def __init__( + self, + data: pd.DataFrame, + index_column: str, + min_support: float, + min_confidence: float, + ): + self.min_support = min_support + self.min_confidence = min_confidence + self.data = data.set_index(index_column) + self.columns = self.data.columns + self.transaction_sets = list( + map( + set, + self.data.apply(lambda x: x > 0).apply( + lambda x: list(self.columns[x.values]), axis=1 + ), + ) + ) + self.all_frequent_sets = {} + self.k_frequent_sets_df = None + self.saved_steps: List[dict] = [] + + def run(self, with_steps) -> 
Tuple[pd.DataFrame, pd.DataFrame, List[set]]: + frequent_sets = None + + for k in range( + 1, len(self.columns) + ): # k as in k-item_sets - sets that contain k elements + generated_item_sets = self._generate_item_sets(with_steps, frequent_sets) + new_frequent_sets = {} + for item_set, item_set_support in zip( + generated_item_sets, + map(lambda item_set: self.support(item_set), generated_item_sets), + ): + if item_set_support >= self.min_support: + new_frequent_sets[item_set] = item_set_support + + if with_steps: + self.saved_steps.append( + { + "part": APrioriPartLabel.CALCULATE_SUPPORT, + "set": item_set, + "support": item_set_support, + "min_support": self.min_support, + "data_frame": self.k_frequent_sets_df, + } + ) + + if with_steps: + self.saved_steps.append( + { + "part": APrioriPartLabel.FILTER_BY_SUPPORT, + "frequent_sets": list(new_frequent_sets.keys()), + "infrequent_sets": [ + set_ + for set_ in generated_item_sets + if set_ not in new_frequent_sets + ], + "data_frame": self.k_frequent_sets_df, + } + ) + + if not new_frequent_sets: + break + + self.all_frequent_sets |= new_frequent_sets + self.k_frequent_sets_df = self._get_frequent_set_pd(new_frequent_sets) + + if with_steps: + self.saved_steps.append( + { + "part": APrioriPartLabel.SAVE_K_SETS, + "k": k, + "data_frame": self.k_frequent_sets_df, + } + ) + + frequent_sets = new_frequent_sets + + rules = self._get_association_rules(with_steps) + if with_steps: + self.saved_steps.append( + { + "part": APrioriPartLabel.SAVE_RULES, + "data_frame": rules, + } + ) + + return ( + self._get_frequent_set_pd(self.all_frequent_sets), + rules, + self.transaction_sets, + ) + + def get_steps(self) -> List[dict]: + return self.saved_steps + + def _get_frequent_set_pd(self, frequent_sets: dict): + return pd.DataFrame.from_dict( + { + format_set(frequent_set): round(self.support(frequent_set), 3) + for frequent_set, support in frequent_sets.items() + }, + orient="index", + columns=["support"], + 
).sort_values(by="support", ascending=False) + + def _generate_item_sets( + self, with_steps: bool, frequent_sets: Optional[List[tuple]] + ) -> List[tuple]: + """ + Generates (k+1)-item_sets from k-item_sets + Returns all found sets, not only strong ones + """ + if frequent_sets is None: + return [(item,) for item in self.columns.values] + + new_item_sets = [] + for frequent_set_1, frequent_set_2 in combinations(frequent_sets, 2): + if not ( + frequent_set_1[:-1] == frequent_set_2[:-1] + and frequent_set_1[-1] < frequent_set_2[-1] + ): + continue + + new_item_set = self.join(frequent_set_1, frequent_set_2) + + if with_steps: + self.saved_steps.append( + { + "part": APrioriPartLabel.JOIN_AND_PRUNE, + "set_1": frequent_set_1, + "set_2": frequent_set_2, + "new_set": new_item_set, + "data_frame": self.k_frequent_sets_df, + } + ) + + if not self._has_infrequent_subsets( + with_steps, new_item_set, frequent_sets + ): + new_item_sets.append(new_item_set) + + return new_item_sets + + @staticmethod + def join(frequent_set_1: tuple, frequent_set_2: tuple) -> tuple: + return frequent_set_1 + (frequent_set_2[-1],) + + def _has_infrequent_subsets( + self, with_steps, new_frequent_set, prev_frequent_sets + ) -> bool: + for subset in combinations(new_frequent_set, len(prev_frequent_sets)): + if subset not in prev_frequent_sets: + if with_steps: + self.saved_steps[-1]["infrequent_subset"] = subset + return True + + if with_steps: + self.saved_steps[-1]["infrequent_subset"] = None + return False + + def support(self, item_set: tuple) -> float: + count = 0 + for transaction_set in self.transaction_sets: + if set(item_set).issubset(transaction_set): + count += 1 + return count / len(self.transaction_sets) + + def confidence(self, item_set_a: tuple, item_set_b: tuple) -> float: # a => b + return ( + self.all_frequent_sets[tuple(sorted(set(item_set_a) | set(item_set_b)))] + / self.all_frequent_sets[item_set_a] + ) + + @staticmethod + def get_all_subsets(item_set: tuple): + return 
chain.from_iterable( + combinations(item_set, i) for i in range(len(item_set) + 1) + ) + + def _get_association_rules(self, with_steps) -> pd.DataFrame: + rules = {} + for frequent_set in self.all_frequent_sets.keys(): + for subset_a in self.get_all_subsets(frequent_set): + subset_b = tuple(set(frequent_set) - set(subset_a)) + if not subset_a or not subset_b: + continue + if ( + confidence := self.confidence(subset_a, subset_b) + ) >= self.min_confidence: + rules[f"{format_set(subset_a)} => {format_set(subset_b)}"] = round( + confidence, 3 + ) + + if with_steps: + self.saved_steps.append( + { + "part": APrioriPartLabel.GENERATE_RULES, + "set": frequent_set, + "set_a": subset_a, + "set_b": subset_b, + "confidence": confidence, + "min_confidence": self.min_confidence, + "data_frame": self.k_frequent_sets_df, + } + ) + + return pd.DataFrame.from_dict( + rules, orient="index", columns=["confidence"] + ).sort_values(by="confidence", ascending=False) + + +class APrioriPartLabel(Enum): + CALCULATE_SUPPORT = "Calculating support" + FILTER_BY_SUPPORT = "Selecting frequent sets from generated and not already pruned" + SAVE_K_SETS = "Saving found k-frequent sets" + SAVE_RULES = "Saving found association rules" + JOIN_AND_PRUNE = "Joining sets and pruning ones with infrequent subsets" + GENERATE_RULES = "Generating and verifying potential rules from frequent sets" diff --git a/src/algorithms/classification/__init__.py b/src/algorithms/classification/__init__.py index e69de29..df90178 100644 --- a/src/algorithms/classification/__init__.py +++ b/src/algorithms/classification/__init__.py @@ -0,0 +1 @@ +from .extra_trees import ExtraTrees diff --git a/src/algorithms/classification/extra_trees.py b/src/algorithms/classification/extra_trees.py new file mode 100644 index 0000000..223d274 --- /dev/null +++ b/src/algorithms/classification/extra_trees.py @@ -0,0 +1,419 @@ +from collections import deque +from typing import Callable, Dict, List, Optional, Tuple + +import joblib +import 
matplotlib +import numpy as np +import pandas as pd + +from algorithms import Algorithm, check_numeric, get_threads_count + +metrics_types = ["gini", "entropy"] + + +class Leaf: + def __init__(self, data: pd.Series, info: Optional[List] = None): + self.prediction = data.value_counts().index[0] + self.samples = len(data) + self.info = info + + def graphviz_label(self, get_color: Callable) -> str: + samples_str = f"samples = {self.samples}" + class_str = f"class = {self.prediction}" + color_hex = get_color(self.prediction) + label_str = ( + f"[label=<{samples_str}
{class_str}>, " + f'fillcolor="{color_hex}", shape=circle]' + ) + return label_str + + +class Node: + def __init__(self, label: str, pivot: any, info: Optional[List] = None): + self.label = label + self.pivot = pivot + self.importance = 0 + self.samples = 0 + self.largest_class = None + self.left = None + self.right = None + self.info = info + + def graphviz_label(self, get_color: Callable) -> str: + if isinstance(self.pivot, bool) or isinstance(self.pivot, str): + pivot_str = f"{self.label} == {self.pivot}" + elif isinstance(self.pivot, pd.Series): + pivot_str = f"{self.label} in {self.pivot}" + elif check_numeric(self.pivot): + pivot_str = f"{self.label} > {self.pivot}" + else: + raise TypeError( + f"pivot must be bool, string, list or number not {type(self.pivot)}" + ) + samples_str = f"samples = {self.samples}" + class_str = f"class = {self.largest_class}" + color_hex = get_color(self.largest_class) + label_str = ( + f"[label=<{pivot_str}
{samples_str}
{class_str}>," + f' fillcolor="{color_hex}", shape=box]' + ) + return label_str + + def simple_graphviz_label(self, color: str) -> str: + samples_str = f"samples = {self.samples}" + class_str = f"class = {self.largest_class}" + label_str = ( + f"[label=<{samples_str}
{class_str}>, " + f'fillcolor="{color}", shape=box]' + ) + return label_str + + +class DecisionTree: + def __init__( + self, + data: pd.DataFrame, + with_steps: bool, + label_name: str, + features_number: int, + min_child_number: int, + max_depth: Optional[int], + min_metrics: float = 0.0, + metrics_type: metrics_types = "gini", + ): + self.data = data + self.with_steps = with_steps + self.label_name = label_name + self.features_number = features_number + self.min_child_number = min_child_number + self.max_depth = max_depth + self.min_metrics = min_metrics + self.metrics_type = metrics_type + self.features = pd.Series(data.columns.drop(self.label_name)) + self.root = self.create_node(pd.Series([True] * len(data), index=data.index), 1) + + def create_node(self, mask: pd.Series, depth) -> Node | Leaf: + if not self._can_be_split(mask, depth): + return Leaf(self.data[self.label_name][mask]) + features = self.features.sample(self.features_number) + best_metrics = None + left_mask = None + right_mask = None + best_feature = None + best_pivot = None + if self.with_steps: + choose_info = [] + else: + choose_info = None + for feature in features: + values = pd.Series(list(set(self.data[feature][mask]))) + if check_numeric(self.data[feature]): + pivot = values.sample(1).iloc[0] + else: + if len(values) <= 2: + pivot = values.sample(1).iloc[0] + else: + n = np.random.randint(1, len(values)) + pivot = values.sample(n) + result = self._split(mask, feature, pivot) + if result is None: + if self.with_steps: + choose_info.append([feature, pivot, None]) + continue + metrics, left, right = result + if self.with_steps: + choose_info.append([feature, pivot, round(metrics, 3)]) + if best_metrics is None or metrics > best_metrics: + best_metrics = metrics + left_mask = left + right_mask = right + best_feature = feature + best_pivot = pivot + + if best_metrics is None: + return Leaf(self.data[self.label_name][mask], choose_info) + + node = Node(best_feature, best_pivot, choose_info) + 
node.importance = best_metrics + node.left = self.create_node(left_mask, depth + 1) + node.right = self.create_node(right_mask, depth + 1) + node.samples = np.sum(mask) + node.largest_class = self.data[self.label_name][mask].value_counts().index[0] + return node + + def _can_be_split(self, mask: pd.Series, depth: int) -> bool: + if len(set(self.data[self.label_name][mask])) == 1: # pure node + return False + if self.max_depth is not None and self.max_depth <= depth: + return False + if np.sum(mask) < 2 * self.min_child_number: + return False + return True + + def _split( + self, mask: pd.Series, label: str, pivot: any + ) -> Optional[Tuple[float, pd.Series, pd.Series]]: + def _set_label(row: pd.Series): + if not mask[row.name]: + return 0 + value = row[label] + if isinstance(pivot, bool) or isinstance(pivot, str): + if value == pivot: + return 1 + elif isinstance(pivot, pd.Series): + if value in pivot: + return 1 + elif check_numeric(pivot): + if float(value) > pivot: + return 1 + else: + raise TypeError( + f"pivot must be bool, string, list or number not {type(pivot)}" + ) + return 2 + + division = self.data.apply(_set_label, axis="columns") + right_mask = division == 1 + left_mask = division == 2 + if ( + np.sum(left_mask) < self.min_child_number + or np.sum(right_mask) < self.min_child_number + ): + return None + + metrics = self._calculate_metrics(mask, left_mask, right_mask) + + if metrics < self.min_metrics: + return None + + return metrics, left_mask, right_mask + + def _calculate_metrics( + self, mask: pd.Series, left_mask: pd.Series, right_mask: pd.Series + ) -> float: + match self.metrics_type: + case "gini": + metrics_func = self._calculate_gini + case "entropy": + metrics_func = self._calculate_entropy + case _: + raise ValueError(f"'{self.metrics_type}' is not valid type of metrics") + parent = metrics_func(mask) + left = metrics_func(left_mask) + right = metrics_func(right_mask) + split = np.sum(left_mask) * left + np.sum(right_mask) * right + 
decrease = (np.sum(mask) * parent - split) / len(mask) + return decrease + + def _calculate_gini(self, mask: pd.Series): + labels = self.data[self.label_name][mask] + gini = 1 - np.sum((labels.value_counts() / len(labels)) ** 2) + return gini + + def _calculate_entropy(self, mask: pd.Series): + labels = self.data[self.label_name][mask] + prob = labels.value_counts() / len(labels) + entropy = -1 * np.sum(prob * np.log2(prob)) + return entropy + + def predict(self, record: pd.Series) -> pd.Series: + node = self.root + while not isinstance(node, Leaf): + node = self._next_node(record, node) + return node.prediction + + @staticmethod + def _next_node(record: pd.Series, node: Node) -> Node | Leaf: + value = record[node.label] + pivot = node.pivot + if isinstance(pivot, bool) or isinstance(pivot, str): + if value == pivot: + return node.right + elif isinstance(pivot, pd.Series): + if value in pivot: + return node.right + elif check_numeric(pivot): + if float(value) > pivot: + return node.right + else: + raise TypeError( + f"pivot must be bool, string, list or number not {type(pivot)}" + ) + return node.left + + def calculate_importance(self) -> pd.Series: + def add_value(node): + if isinstance(node, Leaf): + return + values[node.label] += node.importance + add_value(node.left) + add_value(node.right) + + values = pd.Series(0, index=self.features) + add_value(self.root) + values = values / np.sum(values) + return values + + def graphviz_str(self, get_color: Callable) -> Tuple[str, Dict, List]: + def make_next_row(num_from, node_next, side=None): + idx[0] += 1 + num_next = idx[0] + rows.append(f"{num_next} {node_next.graphviz_label(get_color)} ;") + if side is None: + rows.append(f"{num_from} -> {num_next} [width=3] ;") + elif side: + rows.append(f"{num_from} -> {num_next} [color=deepskyblue, width=3] ;") + else: + rows.append(f"{num_from} -> {num_next} [color=orange, width=3] ;") + if isinstance(node_next, Node): + make_next_row(num_next, node_next.left, False) + 
make_next_row(num_next, node_next.right, True) + + rows = [f"0 {self.root.graphviz_label(get_color)} ;"] + idx = [0] + if isinstance(self.root, Node): + make_next_row(0, self.root.left, False) + make_next_row(0, self.root.right, True) + rows_str = "\n".join(rows) + dot_str = ( + "digraph Tree {\n" + 'node [shape=box, style="filled, rounded", color="black", ' + 'fontname="helvetica"] ;\n' + f"{rows_str}\n" + "}" + ) + if len(rows) == 1: + creation = ({0: self.root.info}, [dot_str]) + else: + creation = self.creation_steps(get_color) + return dot_str, creation[0], creation[1] + + @staticmethod + def _make_string(rows: List, nodes: Dict) -> str: + rows_str = "\n".join(list(nodes.values()) + rows) + rows_str = rows_str.replace("shape=circle", "shape=ellipse") + dot_str = ( + "digraph Tree {\n" + 'node [shape=box, style="filled, rounded", ' + 'color="black", fontname="helvetica"] ;\n' + f"{rows_str}\n" + "}" + ) + return dot_str + + def creation_steps(self, get_color: Callable) -> Tuple[Dict, List[str]]: + steps = [] + nodes = {} + rows = [] + queue = deque() + idx = 0 + queue.append((self.root, idx)) + idx += 1 + creation_info = {} + while len(queue): + node, i = queue.popleft() + nodes[i] = f"{i} {node.simple_graphviz_label('blue')} ;" + steps.append(self._make_string(rows, nodes)) + creation_info[len(steps) - 1] = node.info + + nodes[i] = f"{i} {node.graphviz_label(get_color)} ;" + left = idx + right = idx + 1 + idx += 2 + rows.append(f"{i} -> {left} [color=orange, width=3] ;") + rows.append(f"{i} -> {right} [color=deepskyblue, width=3] ;") + if isinstance(node.left, Node): + nodes[left] = f"{left} {node.left.simple_graphviz_label('orange')} ;" + queue.append((node.left, left)) + else: + nodes[left] = f"{left} {node.left.graphviz_label(get_color)} ;" + if isinstance(node.right, Node): + nodes[right] = f"{right} {node.right.simple_graphviz_label('orange')} ;" + queue.append((node.right, right)) + else: + nodes[right] = f"{right} {node.right.graphviz_label(get_color)} 
;" + steps.append(self._make_string(rows, nodes)) + return creation_info, steps + + +class ExtraTrees(Algorithm): + def __init__(self, data: pd.DataFrame, forest_size: int, **tree_parameters): + self.data = data.sample(frac=0.8) + self.test_data = data.drop(self.data.index) + self.forest_size = forest_size + self.forest = None + self.tree_parameters = tree_parameters + self.feature_importance = None + self.labels = np.array(list(set(self.data[tree_parameters["label_name"]]))) + + def get_config(self): + return [ + (column, self.data[column].dtype) + for column in self.data.columns.drop(self.tree_parameters["label_name"]) + ] + + @staticmethod + def _split_int_to_array(num: int, div: int) -> List[int]: + greater = num % div + lower = div - greater + value = num // div + return [value + 1] * greater + [value] * lower + + def run(self, with_steps: bool): + def do_job(num): + forest = [ + DecisionTree(self.data, with_steps=with_steps, **self.tree_parameters) + for _ in range(num) + ] + feature_importance = pd.DataFrame( + [tree.calculate_importance() for tree in forest] + ) + return forest, feature_importance + + threads_count = get_threads_count() + arr = self._split_int_to_array(self.forest_size, threads_count) + results = joblib.Parallel(n_jobs=threads_count)( + joblib.delayed(do_job)(num) for num in arr + ) + self.forest = sum([result[0] for result in results], []) + self.feature_importance = ( + pd.concat([result[1] for result in results]).sum(axis="index") + / self.forest_size + ) + self.feature_importance.sort_values(ascending=False, inplace=True) + self._update_metric() + return self.predict, self.get_config(), self.feature_importance + + def predict(self, record: pd.Series) -> any: + predictions = ( + pd.Series([tree.predict(record) for tree in self.forest]).value_counts() + / self.forest_size + ) + result = pd.Series(0, index=self.labels) + for idx, item in predictions.items(): + result[idx] = item + result.sort_values(ascending=False, inplace=True) + return 
result + + def get_steps(self) -> List: + return [tree.graphviz_str(self._get_color) for tree in self.forest] + + def _get_color(self, label: str) -> str: + normalize = matplotlib.colors.Normalize(vmin=0, vmax=len(self.labels)) + colormap = matplotlib.cm.get_cmap("gist_rainbow") + index = np.argwhere(self.labels == label)[0] + color = [min(1, 1.2 * c) for c in colormap(normalize(index))[0]] + return matplotlib.colors.to_hex(color) + + def _update_metric(self): + self.metrics_info = {} + if not len(self.test_data): + return + count_true = 0 + for _, row in self.test_data.iterrows(): + target = row[self.tree_parameters["label_name"]] + prediction = self.predict(row.drop(self.tree_parameters["label_name"])) + predicted = list(prediction.index)[0] + if target == predicted: + count_true += 1 + self.metrics_info["accuracy"] = round(count_true / len(self.test_data), 3) diff --git a/src/algorithms/clustering/__init__.py b/src/algorithms/clustering/__init__.py index e69de29..763cbc0 100644 --- a/src/algorithms/clustering/__init__.py +++ b/src/algorithms/clustering/__init__.py @@ -0,0 +1,2 @@ +from .gmm import GMM +from .k_means import KMeans diff --git a/src/algorithms/clustering/gmm.py b/src/algorithms/clustering/gmm.py new file mode 100644 index 0000000..7c6ac03 --- /dev/null +++ b/src/algorithms/clustering/gmm.py @@ -0,0 +1,153 @@ +from math import inf + +import numpy as np +import pandas as pd +from numpy.linalg import LinAlgError +from PyQt5.QtWidgets import QMessageBox +from scipy.stats import multivariate_normal + +from algorithms import Algorithm + +from .metrics import davies_bouldin_score, dunn_score, silhouette_score + + +class GMM(Algorithm): + def __init__(self, df, num_clusters, eps=1e-6, max_iterations=None): + self.df = df.select_dtypes(include=["number"]) + self.num_clusters = num_clusters + self.rows = self.df.shape[0] + self.dim = len(self.df.columns) + self.labels = np.floor_divide( + np.arange(self.rows), self.rows / self.num_clusters + ) + self.mu_arr 
= self.initialize_mu() + self.sigma_arr = self.initialize_sigma() + self.pi_arr = self.initialize_pi() + self.prob_matrix = None + self.eps = float(eps) + self.max_iter = max_iterations or inf + self.scenes = [ + (np.array(self.labels, dtype="int64"), self.mu_arr, self.sigma_arr) + ] + + def run(self, with_steps): + try: + prev_ll = self.log_likelihood() + i = 0 + while i < self.max_iter: + self.e_step() + self.m_step() + if with_steps: + self.scenes.append( + (self.get_cluster_labels(), self.mu_arr, self.sigma_arr) + ) + new_ll = self.log_likelihood() + if abs(new_ll - prev_ll) < self.eps: + break + prev_ll = new_ll + i += 1 + except LinAlgError: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText( + "Oops, singular matrix error. " + "This data may not be good for this algorithm. " + "Try to modify params or choose different algorithm." + ) + error.setWindowTitle("Error") + error.exec_() + return + labels = self.get_cluster_labels() + self.update_metrics(labels) + return ( + labels, + pd.DataFrame(self.mu_arr, columns=self.df.columns), + self.sigma_arr, + ) + + def get_cluster_labels(self): + return np.argmax(self.prob_matrix, axis=1) + + def get_steps(self): + return self.scenes + + def initialize_mu(self): + size = (self.num_clusters, self.dim) + return np.random.random_sample(size) + + def initialize_sigma(self): + return np.tile(np.eye(self.dim) * 100, (self.num_clusters, 1, 1)) + + def initialize_pi(self): + _, counts = np.unique(self.labels, return_counts=True) + return np.divide(counts, self.rows) + + def e_step(self): + self.prob_matrix = np.fromfunction( + np.vectorize(self.calculate_prob), (self.rows, self.num_clusters), dtype=int + ) + self.prob_matrix = self.prob_matrix / self.prob_matrix.sum( + axis=1, keepdims=True + ) + + def calculate_prob(self, i, c): + return self.pi_arr[c] * multivariate_normal.pdf( + self.df.iloc[i], mean=self.mu_arr[c], cov=self.sigma_arr[c] + ) + + def m_step(self): + self.pi_arr = 
np.divide(np.sum(self.prob_matrix, axis=0), self.rows) + self.update_mu() + self.update_sigma() + + def update_mu(self): + sums = [ + np.sum((self.df.T * self.prob_matrix[:, c]).T, axis=0) + for c in range(self.num_clusters) + ] + sums = [ + sums[c] / (self.pi_arr[c] * self.rows) for c in range(self.num_clusters) + ] + self.mu_arr = np.array(sums) + + def update_sigma(self): + for c in range(self.num_clusters): + new_sigma_component = np.zeros((self.dim, self.dim)) + for i in range(self.rows): + data_row = np.array(self.df.iloc[i]) - self.mu_arr[c] + new_sigma_component += self.prob_matrix[i, c] * np.outer( + data_row.T, data_row + ) + self.sigma_arr[c] = new_sigma_component / (self.pi_arr[c] * self.rows) + + def log_likelihood(self): + result = 0 + width = height = len(self.sigma_arr[0][0]) + covariance_regulator = 1e-4 * np.eye(width) + 1e-6 * np.ones((width, height)) + for i in range(self.rows): + row_result = 0 + for c in range(self.num_clusters): + cov_matrix = self.sigma_arr[c] + if not is_invertible(self.sigma_arr[c]): + cov_matrix += covariance_regulator + row_result += self.pi_arr[c] * multivariate_normal.pdf( + self.df.iloc[i], mean=self.mu_arr[c], cov=cov_matrix + ) + result += np.log(row_result) + return result + + def update_metrics(self, labels): + self.metrics_info = {} + d_index = dunn_score(self.df, labels) + db_index = davies_bouldin_score(self.df, labels) + s_index = silhouette_score(self.df, labels) + self.metrics_info["Dunn index (higher = better)"] = round(d_index, 3) + self.metrics_info["Davies Bouldin index (lower = better)"] = round(db_index, 3) + self.metrics_info["Silhouette Coefficient (higher = better)"] = round( + s_index, 3 + ) + + +def is_invertible(matrix): + width, height = matrix.shape + return width == height and np.linalg.matrix_rank(matrix) == width diff --git a/src/algorithms/clustering/k_means.py b/src/algorithms/clustering/k_means.py new file mode 100644 index 0000000..0c79ab2 --- /dev/null +++ 
b/src/algorithms/clustering/k_means.py @@ -0,0 +1,178 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import pandas as pd + +from algorithms import Algorithm + +from .metrics import davies_bouldin_score, dunn_score, silhouette_score + +init_types = ["random", "kmeans++"] + + +class KMeans(Algorithm): + def __init__( + self, + data: pd.DataFrame, + num_clusters: int, + metrics: int = 1, + iterations: Optional[int] = None, + repeats: int = 1, + init_type: init_types = "random", + ): + self.num_clusters = num_clusters + self.metrics = metrics + self.max_iterations = iterations + self.repeats = repeats + if init_type not in init_types: + raise TypeError(f"{init_type} is invalid value of init_type parameter") + self.step_counter = 0 + self.data = data.select_dtypes(include=["number"]) + self.centroids = [] + self.labels = np.zeros(self.data.shape[0], dtype=int) + self.saved_steps = [] + self.get_centroids = { + "random": self.random_centroids, + "kmeans++": self.kmeanspp_centroids, + }[init_type] + + def distance( + self, vector_x: Union[Tuple, List], vector_y: Union[Tuple, List] + ) -> float: + diff = np.abs(np.array(vector_x) - np.array(vector_y)) + return (np.sum(diff**self.metrics)) ** (1 / self.metrics) + + def random_centroids(self) -> List[Tuple]: + return list( + self.data.sample(self.num_clusters, replace=False).itertuples(index=False) + ) + + def kmeanspp_centroids(self) -> List[Tuple]: + centroids = list(self.data.sample(1, replace=False).itertuples(index=False)) + for _ in range(self.num_clusters - 1): + centroid = None + max_dis = 0 + for row in self.data.itertuples(index=False): + min_dis = np.inf + for cent in centroids: + dis = self.distance(row, cent) + if dis < min_dis: + min_dis = dis + if min_dis > max_dis: + max_dis = min_dis + centroid = row + centroids.append(centroid) + return centroids + + def mark_labels(self) -> int: + count = 0 + for i, row in enumerate(self.data.itertuples(index=False)): + min_dis = np.inf + m = 0 
+ for j, centroid in enumerate(self.centroids): + dis = self.distance(row, centroid) + if dis < min_dis: + min_dis = dis + m = j + if self.labels[i] != m: + count += 1 + self.labels[i] = m + return count + + @staticmethod + def mean(group: pd.DataFrame) -> Tuple: + return tuple(group.mean(axis="index")) + + def update_centroids(self): + for i, centroid in enumerate(self.centroids): + group = self.data[self.labels == i] + self.centroids[i] = self.mean(group) + + def step(self) -> bool: + self.update_centroids() + count = self.mark_labels() + if count == 0: + return False + return True + + def update_metrics(self, labels): + self.metrics_info = {} + d_index = dunn_score(self.data, labels) + db_index = davies_bouldin_score(self.data, labels) + s_index = silhouette_score(self.data, labels) + self.metrics_info["Dunn index (higher = better)"] = round(d_index, 3) + self.metrics_info["Davies Bouldin index (lower = better)"] = round(db_index, 3) + self.metrics_info["Silhouette Coefficient (higher = better)"] = round( + s_index, 3 + ) + + def check_solution(self, labels): + """dunn index""" + return dunn_score(self.data, labels, self.distance) + + def run(self, with_steps) -> Tuple[np.ndarray, pd.DataFrame]: + runner = ( + self.run_with_saving_steps if with_steps else self.run_without_saving_steps + ) + if self.repeats == 1: + solution = runner() + self.update_metrics(solution[0]) + return solution[0], pd.DataFrame(solution[1], columns=self.data.columns) + best_value = 0 + solution = None + steps = None + for _ in range(self.repeats): + result = runner() + value = self.check_solution(result[0]) + if value > best_value: + solution = (result[0].copy(), result[1].copy()) + steps = [(step[0].copy(), step[1].copy()) for step in self.saved_steps] + best_value = value + self.saved_steps = steps + self.update_metrics(solution[0]) + return solution[0], pd.DataFrame(solution[1], columns=self.data.columns) + + def run_with_saving_steps(self) -> Tuple[np.ndarray, List[Tuple]]: + steps = 
0 + self.saved_steps = [] + self.centroids = self.get_centroids() + self.mark_labels() + self.saved_steps.append( + ( + self.labels.copy(), + pd.DataFrame(self.centroids, columns=self.data.columns), + ) + ) + while self.step(): + steps += 1 + self.saved_steps.append( + ( + self.labels.copy(), + pd.DataFrame(self.centroids, columns=self.data.columns), + ) + ) + if self.max_iterations and steps > self.max_iterations: + break + self.step_counter = steps + self.saved_steps.append( + ( + self.labels.copy(), + pd.DataFrame(self.centroids, columns=self.data.columns), + ) + ) + return self.labels.copy(), self.centroids + + def run_without_saving_steps(self) -> Tuple[np.ndarray, List[Tuple]]: + steps = 0 + self.saved_steps = [] + self.centroids = self.get_centroids() + self.mark_labels() + while self.step(): + steps += 1 + if self.max_iterations is not None and steps > self.max_iterations: + break + self.step_counter = steps + return self.labels, self.centroids + + def get_steps(self) -> List[Tuple[np.ndarray, pd.DataFrame]]: + return self.saved_steps diff --git a/src/algorithms/clustering/metrics.py b/src/algorithms/clustering/metrics.py new file mode 100644 index 0000000..5efa8e6 --- /dev/null +++ b/src/algorithms/clustering/metrics.py @@ -0,0 +1,100 @@ +import numpy as np +import pandas as pd + + +def distance(x, y): + return np.linalg.norm(x - y) + + +def davies_bouldin_score(df: pd.DataFrame, labels: np.array): + clusters = np.unique(labels) + num_cluster = len(clusters) + centroids = [] + avg_distances = [] + for cluster in clusters: + cluster_set = df.loc[labels == cluster].to_numpy() + centroid = cluster_set.mean(axis=0) + avg_distance = np.linalg.norm( + cluster_set - np.tile(centroid, (len(cluster_set), 1)), axis=1 + ).mean() + centroids.append(centroid) + avg_distances.append(avg_distance) + R_matrix = np.zeros((num_cluster, num_cluster), dtype=np.float) + for i in range(num_cluster): + x = centroids[i] + for j in range(i + 1, num_cluster): + y = centroids[j] + 
avg_distance = avg_distances[i] + avg_distances[j] + centroids_dis = distance(x, y) + if avg_distance == 0: + R_matrix[i, j] = R_matrix[j, i] = 0 + continue + if centroids_dis == 0: + return np.inf + R_matrix[i, j] = R_matrix[j, i] = avg_distance / centroids_dis + return np.amax(R_matrix, axis=0).mean() + + +def dunn_score(df: pd.DataFrame, labels: np.array, distance_function=distance): + clusters = np.unique(labels) + centroids = [] + for cluster in clusters: + cluster_set = df.loc[labels == cluster].to_numpy() + centroid = cluster_set.mean(axis=0) + centroids.append(centroid) + max_distance_intra = 0 + min_distance_inter = np.inf + for i, first in enumerate(centroids): + for second in centroids[i + 1 :]: + dis = distance_function(first, second) + min_distance_inter = min(min_distance_inter, dis) + for cluster in clusters: + cluster_df = df.loc[labels == cluster] + size = len(cluster_df) + for i, (_, pointA) in enumerate(cluster_df.iterrows()): + if i + 1 == size: + continue + dis = np.linalg.norm( + cluster_df.iloc[i + 1 :].to_numpy() + - np.tile(pointA.to_numpy(), (size - i - 1, 1)), + axis=1, + ).max() + max_distance_intra = max(max_distance_intra, dis) + if min_distance_inter == 0: + return 0 + if max_distance_intra == 0: + return np.inf + return min_distance_inter / max_distance_intra + + +def silhouette_score(df: pd.DataFrame, labels: np.array): + clusters = np.unique(labels) + dis_same_class = np.zeros_like(labels, dtype=np.float) + dis_next_class = np.zeros_like(labels, dtype=np.float) + for i, row in df.iterrows(): + cluster = labels[i] + same_cluster_df = df.loc[labels == cluster].drop(index=i).to_numpy() + size = len(same_cluster_df) + dis_same_class[i] = ( + np.linalg.norm( + same_cluster_df - np.tile(row.to_numpy(), (size, 1)), axis=1 + ).mean() + if size + else 0 + ) + + min_dis = np.inf + for j in clusters: + if cluster == j: + continue + cluster_df = df.loc[labels == j].to_numpy() + dis = np.linalg.norm( + cluster_df - np.tile(row.to_numpy(), 
(len(cluster_df), 1)), axis=1 + ).mean() + min_dis = min(min_dis, dis) + dis_next_class[i] = min_dis + denominator = np.maximum(dis_same_class, dis_next_class) + if 0 in denominator: + return np.inf + silhouette_per_sample = (dis_next_class - dis_same_class) / denominator + return silhouette_per_sample.mean() diff --git a/src/algorithms/generate_number.py b/src/algorithms/generate_number.py deleted file mode 100644 index 0b78c81..0000000 --- a/src/algorithms/generate_number.py +++ /dev/null @@ -1,12 +0,0 @@ -import numpy as np - - -class NumberGenerator: - - def __init__(self, start: int, end: int): - self.start = start - self.end = end - - def get_number(self) -> float: - value = np.random.random() * (self.end - self.start) + self.start - return round(value, 2) diff --git a/src/algorithms/utils.py b/src/algorithms/utils.py new file mode 100644 index 0000000..242fc19 --- /dev/null +++ b/src/algorithms/utils.py @@ -0,0 +1,21 @@ +import os +from typing import List + +import numpy as np +import pandas as pd + + +def get_samples(arr_size, num_samples) -> List: + return np.random.choice(arr_size, num_samples, replace=False) + + +def check_numeric(element: any) -> bool: + try: + pd.to_numeric(element) + return True + except (ValueError, TypeError): + return False + + +def get_threads_count(): + return os.cpu_count() - 2 diff --git a/src/app.py b/src/app.py index 2a890bf..fc98452 100644 --- a/src/app.py +++ b/src/app.py @@ -1,14 +1,33 @@ import sys + +import matplotlib as plt from PyQt5.QtWidgets import QApplication -from widgets.main_layout import RandomGenerator +from engines import ( + AlgorithmsEngine, + ImportDataEngine, + PreprocessingEngine, + ResultsEngine, +) +from state import State +from widgets import MainWindow def main(): app = QApplication(sys.argv) - window = RandomGenerator() + state = State() + algorithm_engine = AlgorithmsEngine(state) + engines = { + "import_data": ImportDataEngine(state), + "preprocess": PreprocessingEngine(state), + 
"algorithm_setup": algorithm_engine, + "algorithm_run": algorithm_engine, + "results": ResultsEngine(state), + } + plt.rcParams.update({"font.size": 7}) + window = MainWindow(engines) sys.exit(app.exec_()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/data_generators/__init__.py b/src/data_generators/__init__.py new file mode 100644 index 0000000..2121c0d --- /dev/null +++ b/src/data_generators/__init__.py @@ -0,0 +1,10 @@ +from typing import Callable, Dict, TypeAlias + +import pandas as pd + +from .clustering_data_generator import ( + noncentral_f_blobs_generator, + normal_distribution_blobs_generator, +) + +DataGeneratorFunction: TypeAlias = Callable[[Dict], pd.DataFrame] diff --git a/src/data_generators/clustering_data_generator.py b/src/data_generators/clustering_data_generator.py new file mode 100644 index 0000000..828bd46 --- /dev/null +++ b/src/data_generators/clustering_data_generator.py @@ -0,0 +1,81 @@ +from typing import Dict + +import numpy as np +import pandas as pd + +SCALING_FACTOR = 50 + + +def normal_distribution_blobs_generator(options: Dict) -> pd.DataFrame: + blobs_number = options["blobs_number"] + sample_sizes = options["sample_sizes"] + + dims_number = options["dims_number"] + blobs_dims_stds = options["dims_stds"] + + seed = options.get("seed") + np.random.seed(seed) + + noise_percentage = options["noise"] + + centers = np.random.rand(blobs_number, dims_number) * SCALING_FACTOR + + data = pd.DataFrame() + + for blob_i, (sample_size, center, dims_stds) in enumerate( + zip(sample_sizes, centers, blobs_dims_stds), start=1 + ): + noise_count = round(sample_size * noise_percentage) + blob_data = pd.DataFrame() + normal_data = np.random.multivariate_normal( + center, np.diag(np.array(dims_stds)), sample_size + ).T + + for dim_i, normal_data_i in enumerate(normal_data, start=1): + blob_data[f"Dim #{dim_i}"] = np.concatenate( + [ + normal_data_i, + np.random.rand(noise_count) * SCALING_FACTOR, + ] + ) + 
blob_data["Blob number"] = np.full(sample_size + noise_count, fill_value=blob_i) + data = pd.concat([data, blob_data]) + + return data.sample(frac=1).reset_index(drop=True).round(3) + + +def noncentral_f_blobs_generator(options: Dict) -> pd.DataFrame: + blobs_number = options["blobs_number"] + sample_sizes = options["sample_sizes"] + + dims_number = options["dims_number"] + + blobs_dfnums = options["df_nums"] + blobs_dfdens = options["df_dens"] + + seed = options.get("seed") + np.random.seed(seed) + + noise_percentage = options["noise"] + + centers = np.random.rand(blobs_number, dims_number) * SCALING_FACTOR + + data = pd.DataFrame() + + for blob_i, (sample_size, center, df_num, df_den) in enumerate( + zip(sample_sizes, centers, blobs_dfnums, blobs_dfdens), start=1 + ): + noise_count = round(sample_size * noise_percentage) + blob_data = pd.DataFrame() + for dim_i, center_loc in enumerate(center, start=1): + blob_data[f"Dim #{dim_i}"] = np.concatenate( + [ + np.random.noncentral_f(df_num, df_den, center_loc, size=sample_size) + + center_loc, + np.random.rand(noise_count) * SCALING_FACTOR, + ] + ) + blob_data["Blob number"] = np.full(sample_size + noise_count, fill_value=blob_i) + data = pd.concat([data, blob_data]) + + return data.sample(frac=1).reset_index(drop=True).round(3) diff --git a/src/data_import/__init__.py b/src/data_import/__init__.py new file mode 100644 index 0000000..39d2f12 --- /dev/null +++ b/src/data_import/__init__.py @@ -0,0 +1,6 @@ +from .config import AVAILABLE_RAM_MEMORY, SIZE_OF_VALUE +from .file_reader import FileReader +from .csv_reader import CSVReader +from .database_reader import DatabaseReader +from .json_reader import JSONReader +from .loader import Loader diff --git a/src/data_import/config.py b/src/data_import/config.py new file mode 100644 index 0000000..37fab29 --- /dev/null +++ b/src/data_import/config.py @@ -0,0 +1,10 @@ +import psutil + +PERCENT_TO_USE = 0.8 +# available memory in bytes +AVAILABLE_RAM_MEMORY = int( + 
psutil.virtual_memory().available * PERCENT_TO_USE +) # / (1024**2) + +# average size of element in data frame +SIZE_OF_VALUE = 8 * 20 diff --git a/src/data_import/csv_reader.py b/src/data_import/csv_reader.py new file mode 100644 index 0000000..7ae5be8 --- /dev/null +++ b/src/data_import/csv_reader.py @@ -0,0 +1,40 @@ +from typing import List, Optional + +import pandas as pd + +from data_import import FileReader + + +class CSVReader(FileReader): + def __init__(self, filepath: str): + try: + super().__init__(filepath) + self.columns_name = list(pd.read_csv(self.filepath, nrows=1).columns) + except FileNotFoundError: + self.error = ( + f"This filepath: {filepath} is invalid. Please write correct path." + ) + except Exception: + self.error = "There is some problem with file. Please try again." + self.reader = None + + # return DataFrame or TextFileReader (can use as generator of DataFrame) + def read(self, columns: Optional[List[str]]): + if self.need_chunks: + self._read_by_chunks(columns) + else: + self._read_all(columns) + return self.reader + + def _read_by_chunks(self, columns: Optional[List[str]]): + chunksize = self.get_chunksize() + self.reader = pd.read_csv( + self.filepath, + usecols=columns, + engine="c", + low_memory=True, + chunksize=chunksize, + ) + + def _read_all(self, columns: List[str]): + self.reader = pd.read_csv(self.filepath, usecols=columns, engine="c") diff --git a/src/data_import/database_reader.py b/src/data_import/database_reader.py new file mode 100644 index 0000000..10d6130 --- /dev/null +++ b/src/data_import/database_reader.py @@ -0,0 +1,66 @@ +from typing import Generator, List, Optional, Union + +import pandas as pd + +from data_import import AVAILABLE_RAM_MEMORY, SIZE_OF_VALUE +from database import Reader + + +class DatabaseReader: + def __init__(self, db_name: str, coll_name: str): + """ + Class to read data from database. + self.reader is DataFrame or Generator of DataFrame. 
+ We may implement some special class + to have data and behave as DataFrame. + """ + self.error = "" + self.need_chunks = False + try: + self.database = Reader(db_name, coll_name) + self.columns_name = self.database.get_columns_names() + + # check size - big data support is not ready + # size = self.database.get_rows_number() + # self.need_chunks = size > self.get_chunksize() + except Exception as e: + print(e) + self.error = "There is some problem with database. Please try again." + self.reader = None + + def get_columns_name(self) -> List[str]: + return self.columns_name + + def get_error(self) -> str: + return self.error + + # approximate size of chunk, we want using ram as good as possible + def get_chunksize(self) -> int: + return AVAILABLE_RAM_MEMORY // (len(self.columns_name) * SIZE_OF_VALUE) + + def is_file_big(self) -> bool: + return self.need_chunks + + def read( + self, columns: Optional[List[str]] + ) -> Union[pd.DataFrame, Generator[pd.DataFrame, None, None]]: + if columns is None: + columns = self.columns_name + if self.need_chunks: + self.reader = self._read_by_chunks(columns) + else: + self._read_all(columns) + return self.reader + + def _read_by_chunks(self, columns: [List[str]]): + chunksize = self.get_chunksize() + chunk_num = 0 + chunks = self.database.get_rows_number() // chunksize + while chunk_num <= chunks: + yield self.database.get_nth_chunk( + columns=columns, chunk_size=chunksize, chunk_number=chunk_num + ) + chunk_num += 1 + + def _read_all(self, columns: List[str]): + self.reader = pd.DataFrame(self.database.execute_query(columns=columns)) diff --git a/src/data_import/file_reader.py b/src/data_import/file_reader.py new file mode 100644 index 0000000..08c7cea --- /dev/null +++ b/src/data_import/file_reader.py @@ -0,0 +1,28 @@ +import os +from typing import List + +from data_import import AVAILABLE_RAM_MEMORY, SIZE_OF_VALUE + + +class FileReader: + def __init__(self, filepath: str): + self.error = "" + self.filepath = filepath + 
self.need_chunks = False + + # check size of file - big data support is not ready + # size = os.stat(self.filepath).st_size + # self.need_chunks = size > AVAILABLE_RAM_MEMORY + + def get_columns_name(self) -> List[str]: + return self.columns_name + + # approximate size of chunk, we want using ram as good as possible + def get_chunksize(self) -> int: + return AVAILABLE_RAM_MEMORY // (len(self.columns_name) * SIZE_OF_VALUE) + + def get_error(self) -> str: + return self.error + + def is_file_big(self): + return self.need_chunks diff --git a/src/data_import/json_reader.py b/src/data_import/json_reader.py new file mode 100644 index 0000000..5da84a2 --- /dev/null +++ b/src/data_import/json_reader.py @@ -0,0 +1,34 @@ +from typing import List, Optional + +import pandas as pd + +from data_import import FileReader + + +class JSONReader(FileReader): + def __init__(self, filepath: str): + try: + super().__init__(filepath) + self.columns_name = list(pd.read_json(self.filepath).columns) + + # if file is big we can not read by chunks because of .json format + if self.need_chunks: + self.error = "File is to big for parsing in .json format" + except FileNotFoundError: + self.error = ( + f"This filepath: {filepath} is invalid. Please write correct path." + ) + except Exception: + self.error = "There is some problem with file. Please try again." 
+ self.reader = None + + def read(self, columns: Optional[List[str]]): + self._read_all(columns) + return self.reader + + def _read_all(self, columns: Optional[List[str]]): + if columns is None: + self.reader = pd.read_json(self.filepath, typ="frame") + self.reader = pd.read_json(self.filepath, typ="frame").filter( + items=columns, axis="columns" + ) diff --git a/src/data_import/loader.py b/src/data_import/loader.py new file mode 100644 index 0000000..44487fa --- /dev/null +++ b/src/data_import/loader.py @@ -0,0 +1,30 @@ +from data_import import CSVReader, JSONReader, DatabaseReader +from engines import DB_NAME + + +class Loader: + def __init__(self): + pass + + def create_file_reader(self, file_path): + reader = None + if not file_path: + raise ValueError("") + if "." not in file_path: + raise ValueError("Supported file format: .csv, .json.") + extension = file_path.split(".")[-1] + if extension == "csv": + reader = CSVReader(file_path) + elif extension == "json": + reader = JSONReader(file_path) + else: + raise ValueError("Supported file format: .csv, .json.") + if error := reader.get_error(): + raise ValueError(error) + return reader + + def create_database_reader(self, document): + reader = DatabaseReader(DB_NAME, document) + if error := reader.get_error(): + raise ValueError(error) + return reader diff --git a/src/database/__init__.py b/src/database/__init__.py index e69de29..8909e50 100644 --- a/src/database/__init__.py +++ b/src/database/__init__.py @@ -0,0 +1,5 @@ +from .database_manager import DatabaseObjectManager +from .database_data_remover import DocumentRemover +from .database_data_updater import DocumentUpdater +from .database_reader import Reader +from .database_writer import Writer diff --git a/src/database/config.py b/src/database/config.py new file mode 100644 index 0000000..28dacff --- /dev/null +++ b/src/database/config.py @@ -0,0 +1,14 @@ +import os + +# fixes database connection on my wifi, to delete if it interferes with other networks 
+import dns.resolver
+from pymongo import MongoClient
+
+dns.resolver.default_resolver = dns.resolver.Resolver(configure=False)
+dns.resolver.default_resolver.nameservers = ["8.8.8.8"]
+
+client = MongoClient(
+    "mongodb+srv://admin:{}@dataminingtooldb.trcgm.mongodb.net/".format(
+        os.environ.get("MONGO_PASS")
+    )
+)
diff --git a/src/database/database_data_remover.py b/src/database/database_data_remover.py
new file mode 100644
index 0000000..fea0a6c
--- /dev/null
+++ b/src/database/database_data_remover.py
@@ -0,0 +1,16 @@
+from database import DatabaseObjectManager
+
+
+class DocumentRemover:
+    def __init__(self, db_name, coll_name):
+        self.db_manager = DatabaseObjectManager()
+        self.db = self.db_manager.get_database(db_name)
+        self.collection = self.db_manager.get_collection(db_name, coll_name)
+
+    def query_remove(self, query):
+        """Remove all elements that passed the query"""
+        return self.collection.delete_many(query)
+
+    def remove_all(self):
+        """Clear a collection"""
+        return self.collection.delete_many({})
diff --git a/src/database/database_data_updater.py b/src/database/database_data_updater.py
new file mode 100644
index 0000000..f84018a
--- /dev/null
+++ b/src/database/database_data_updater.py
@@ -0,0 +1,13 @@
+from database import DatabaseObjectManager
+
+
+class DocumentUpdater:
+    def __init__(self, db_name, coll_name):
+        self.db_manager = DatabaseObjectManager()
+        self.db = self.db_manager.get_database(db_name)
+        self.collection = self.db_manager.get_collection(db_name, coll_name)
+
+    def query_update(self, query, new_values):
+        """Update all queried records with values from new_values dictionary"""
+        updated = {"$set": new_values}
+        return self.collection.update_many(query, updated)
diff --git a/src/database/database_manager.py b/src/database/database_manager.py
new file mode 100644
index 0000000..0b48779
--- /dev/null
+++ b/src/database/database_manager.py
@@ -0,0 +1,42 @@
+from .config import client
+
+
+class DatabaseObjectManager:
+    def __init__(self):
+        self.db_client = client
+
+    def get_database(self, db_name):
+        """Get database by provided name or create new one if it not exists"""
+        return self.db_client[db_name]
+
+    def get_databases_list(self):
+        """Get list of all databases"""
+        return self.db_client.list_database_names()
+
+    def find_database(self, db_name):
+        """Check if database with this name is in system"""
+        db_list = self.db_client.list_database_names()
+        return db_name in db_list
+
+    def remove_database(self, db_name):
+        """Remove unwanted database"""
+        self.db_client.drop_database(db_name)
+
+    def get_collection(self, db_name, collection_name):
+        """Get collection by provided name from specified database
+        or create new one if it not exists"""
+        return self.db_client[db_name][collection_name]
+
+    def get_collections_list(self, db_name):
+        """Get list of all collections in the database"""
+        return self.db_client[db_name].list_collection_names()
+
+    def find_collection(self, db_name, coll_name):
+        """Check if collection with this name is in the database"""
+        db = self.db_client[db_name]
+        coll_list = db.list_collection_names()
+        return coll_name in coll_list
+
+    def remove_collection(self, db_name, coll_name):
+        """Remove unwanted collection"""
+        self.db_client[db_name][coll_name].drop()
diff --git a/src/database/database_reader.py b/src/database/database_reader.py
new file mode 100644
index 0000000..224ccc9
--- /dev/null
+++ b/src/database/database_reader.py
@@ -0,0 +1,70 @@
+import pandas as pd
+
+from database import DatabaseObjectManager
+
+
+class Reader:
+    def __init__(self, db_name, coll_name):
+        self.db_manager = DatabaseObjectManager()
+        self.db = self.db_manager.get_database(db_name)
+        self.collection = self.db_manager.get_collection(db_name, coll_name)
+
+    def execute_query(self, query=None, columns=None, use_id=0, limit=0):
+        """Make a query for specified collection and return result as a list"""
+        if columns is None:
+            columns = self.get_columns_names()
+        if query is None:
+            query = {}
+        fields_selected = {}
+        for name in columns:
+            fields_selected[name] = 1
+        fields_selected["_id"] = use_id
+        return list(
+            self.collection.find(query, fields_selected).limit(limit)
+        )  # maybe changed to another format
+
+    def get_nth_chunk(
+        self, query=None, columns=None, use_id=0, chunk_size=0, chunk_number=0
+    ):
+        """Returns a n-th chunk of data from database,
+        chunks are indexed from 0"""
+        if columns is None:
+            columns = self.get_columns_names()
+        if query is None:
+            query = {}
+        fields_selected = {}
+        for name in columns:
+            fields_selected[name] = 1
+        fields_selected["_id"] = use_id
+        chunk = (
+            self.collection.find(query, fields_selected)
+            .skip(chunk_size * chunk_number)
+            .limit(chunk_size)
+        )
+        return pd.DataFrame(list(chunk))
+
+    def get_rows_number(self):
+        return self.collection.count_documents({})
+
+    def get_columns_names(self):
+        """Return the field names used across the collection, without ``_id``.
+
+        An empty collection yields an empty list instead of raising
+        StopIteration from the exhausted aggregation cursor.
+        """
+        cursor = self.collection.aggregate(
+            [
+                {"$project": {"arrayofkeyvalue": {"$objectToArray": "$$ROOT"}}},
+                {"$unwind": "$arrayofkeyvalue"},
+                {
+                    "$group": {
+                        "_id": "null",
+                        "allkeys": {"$addToSet": "$arrayofkeyvalue.k"},
+                    }
+                },
+            ]
+        )
+        grouped = next(cursor, None)
+        if grouped is None:
+            return []
+        return [key for key in grouped["allkeys"] if key != "_id"]
diff --git a/src/database/database_writer.py b/src/database/database_writer.py
new file mode 100644
index 0000000..63d05b6
--- /dev/null
+++ b/src/database/database_writer.py
@@ -0,0 +1,18 @@
+from database import DatabaseObjectManager
+
+
+class Writer:
+    def __init__(self, db_name, coll_name):
+        self.db_manager = DatabaseObjectManager()
+        self.db = self.db_manager.get_database(db_name)
+        self.collection = self.db_manager.get_collection(db_name, coll_name)
+
+    def add_document(self, record):
+        """Takes dictionary and adds it to collection,
+        returns id of inserted object"""
+        return self.collection.insert_one(record)
+
+    def add_dataset(self, dataframe):
+        """Takes dataframe and adds it to collection,
+        returns list of objects ids"""
+        return self.collection.insert_many(dataframe.to_dict("records"))
diff --git a/src/engines/__init__.py b/src/engines/__init__.py
new file mode 100644
index 0000000..f2579b1
--- /dev/null
+++ b/src/engines/__init__.py
@@ -0,0 +1,5 @@
+from .config import DB_NAME
+from .import_data_engine import ImportDataEngine
+from .preprocessing_engine import PreprocessingEngine
+from .algorithms_engine import AlgorithmsEngine
+from .results_engine import ResultsEngine
diff --git a/src/engines/algorithms_config.py b/src/engines/algorithms_config.py
new file mode 100644
index 0000000..6e4bfec
--- /dev/null
+++ b/src/engines/algorithms_config.py
@@ -0,0 +1,113 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List, Type
+
+from algorithms import Algorithm
+from algorithms.associations import APriori
+from algorithms.classification import ExtraTrees
+from algorithms.clustering import GMM, KMeans
+from widgets.options_widgets import (
+    AlgorithmOptions,
+    AssociationRulesOptions,
+    ExtraTreesOptions,
+    GMMOptions,
+    KMeansOptions,
+)
+from widgets.results_widgets import (
+    AlgorithmResultsWidget,
+    APrioriResultsWidget,
+    ExtraTreesResultsWidget,
+    GMMResultsWidget,
+    KMeansResultsWidget,
+)
+from widgets.steps_widgets import (
+    AlgorithmStepsVisualization,
+    APrioriStepsVisualization,
+    ExtraTreesStepsVisualization,
+    GMMStepsVisualization,
+    KMeansStepsVisualization,
+)
+
+
+class AlgorithmTechniques(Enum):
+    CLUSTERING = "clustering"
+    ASSOCIATIONS = "associations"
+    CLASSIFICATION = "classification"
+
+    @classmethod
+    def list(cls) -> List[str]:
+        return list(map(lambda e: e.value, cls))
+
+
+@dataclass
+class AlgorithmConfig:
+    algorithm: Type[Algorithm]
+    options: Type[AlgorithmOptions]
+    steps_visualization: Type[AlgorithmStepsVisualization]
+    result_widget: Type[AlgorithmResultsWidget]
+    description: str = ""
+
+
+descriptions = {
+    "K-Means": """
+        The K-Means algorithm clusters data by trying to separate samples in n groups of equal variance, minimizing distance to the centers of clusters.
+        This algorithm requires the number of clusters to be specified and works only on numeric data. It scales well to large numbers of samples and has been used across a large range of application areas in many different fields.
+        The K-Means algorithm divides a set of samples into disjoint clusters, each described by the mean of the samples in the cluster. The means are commonly called the cluster “centroids”.
+    """,
+    "Gaussian Mixture Models": """
+        A Gaussian mixture model is a probabilistic model that assumes all the data points are generated from a mixture of a finite number of Gaussian distributions with unknown parameters.
+        This algorithm requires the number of clusters to be specified and works only on numeric data.
+        One can think of mixture models as generalizing k-means clustering to incorporate information about the covariance structure of the data as well as the centers of the latent Gaussians.
+    """,
+    "Apriori": """
+        Apriori Algorithm is a Machine Learning algorithm which is used to gain insight into the structured relationships between different items involved. The most prominent practical application of the algorithm is to recommend products based on the products already present in the user’s cart.
+        This algorithm needs table in special formats. The columns are products and the rows are receipts. Values in table describe number of the product in the receipt.
+    """,
+    "Extra Trees": """
+        The Extremely Randomized Trees is forest of the decision trees. Decision trees are a non-parametric supervised learning method used for classification. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.
+        In Extra Trees create process is fully random. The random subset of candidate features is used in the split step and thresholds are drawn at random for each candidate feature and the best of these randomly-generated thresholds is picked as the splitting rule.
+        This usually allows to reduce the variance of the model a bit more, at the expense of a slightly greater increase in bias.
+    """,
+}
+
+
+def preprocess_description(description: str):
+    return "\n".join([line.strip() for line in description.split("\n")]).strip()
+
+
+ALGORITHMS_INFO: Dict[str, Dict[str, AlgorithmConfig]] = {
+    AlgorithmTechniques.CLUSTERING.value: {
+        "K-Means": AlgorithmConfig(
+            algorithm=KMeans,
+            options=KMeansOptions,
+            steps_visualization=KMeansStepsVisualization,
+            result_widget=KMeansResultsWidget,
+            description=preprocess_description(descriptions["K-Means"]),
+        ),
+        "Gaussian Mixture Models": AlgorithmConfig(
+            algorithm=GMM,
+            options=GMMOptions,
+            steps_visualization=GMMStepsVisualization,
+            result_widget=GMMResultsWidget,
+            description=preprocess_description(descriptions["Gaussian Mixture Models"]),
+        ),
+    },
+    AlgorithmTechniques.ASSOCIATIONS.value: {
+        "Apriori": AlgorithmConfig(
+            algorithm=APriori,
+            options=AssociationRulesOptions,
+            steps_visualization=APrioriStepsVisualization,
+            result_widget=APrioriResultsWidget,
+            description=preprocess_description(descriptions["Apriori"]),
+        )
+    },
+    AlgorithmTechniques.CLASSIFICATION.value: {
+        "Extra Trees": AlgorithmConfig(
+            algorithm=ExtraTrees,
+            options=ExtraTreesOptions,
+            steps_visualization=ExtraTreesStepsVisualization,
+            result_widget=ExtraTreesResultsWidget,
+            description=preprocess_description(descriptions["Extra Trees"]),
+        )
+    },
+}
diff --git a/src/engines/algorithms_engine.py b/src/engines/algorithms_engine.py
new file mode 100644
index 0000000..e956cd5
--- /dev/null
+++ b/src/engines/algorithms_engine.py
@@ -0,0 +1,94 @@
+from typing import List
+
+from state import State
+from widgets.options_widgets import AlgorithmOptions
+
+from .algorithms_config import ALGORITHMS_INFO, AlgorithmTechniques
+
+
+class AlgorithmsEngine:
+    def __init__(self, state: State):
+        self.state = state
+
+        # init options widgets
+        self.options = {}
+        for technique, info in ALGORITHMS_INFO.items():
+            self.options[technique] = {
+                algorithm: classes.options() for algorithm, classes in info.items()
+            }
+
+    def run(
+        self, technique, algorithm, will_be_visualized, is_animation, **kwargs
+    ) -> bool:
+        chosen_alg = ALGORITHMS_INFO[technique][algorithm]
+
+        alg = chosen_alg.algorithm(self.state.imported_data, **kwargs)
+
+        result = alg.run(will_be_visualized)
+        metrics_info = alg.metrics_info
+
+        if result is None:
+            return False
+
+        if will_be_visualized:
+            steps = alg.get_steps()
+            self.state.steps_visualization = chosen_alg.steps_visualization(
+                self.state.imported_data, steps, is_animation
+            )
+        else:
+            self.state.steps_visualization = None
+
+        # create a widget for the results
+        self.state.last_algorithm = (technique, algorithm)
+        if not self.state.algorithm_results_widgets.get(technique):
+            self.state.algorithm_results_widgets[technique] = {}
+        if not self.state.algorithm_results_widgets[technique].get(algorithm):
+            self.state.algorithm_results_widgets[technique][algorithm] = []
+        self.state.algorithm_results_widgets[technique][algorithm].append(
+            chosen_alg.result_widget(
+                self.state.raw_data, *result, options=kwargs, metrics_info=metrics_info
+            )
+        )
+        return True
+
+    def get_maximum_clusters(self) -> int:
+        if self.state.imported_data is None:
+            return 100
+        return self.state.imported_data.shape[0]
+
+    def get_columns(self) -> List:
+        return list(self.state.imported_data.columns)
+
+    @staticmethod
+    def get_all_techniques() -> List:
+        return AlgorithmTechniques.list()
+
+    @staticmethod
+    def get_algorithms_for_techniques(technique: str) -> List:
+        return list(ALGORITHMS_INFO[technique].keys())
+
+    def get_option_widget(
+        self, technique: str, algorithm: str
+    ) -> AlgorithmOptions:
+        return self.options[technique][algorithm]
+
+    def update_options(self):
+        clusters = min(self.get_maximum_clusters(), 100)
+        self.options[AlgorithmTechniques.CLUSTERING.value]["K-Means"].set_max_clusters(
+            clusters
+        )
+        columns = self.get_columns()
+        self.options[AlgorithmTechniques.ASSOCIATIONS.value][
+            "Apriori"
+        ].set_columns_options(columns)
+        self.options[AlgorithmTechniques.CLASSIFICATION.value][
+            "Extra Trees"
+        ].set_values(columns)
+        self.options[AlgorithmTechniques.CLUSTERING.value][
+            "Gaussian Mixture Models"
+        ].set_max_clusters(clusters)
+
+    def get_algorithm_description(
+        self, technique: str, algorithm: str
+    ) -> str:
+        return ALGORITHMS_INFO[technique][algorithm].description
diff --git a/src/engines/config.py b/src/engines/config.py
new file mode 100644
index 0000000..19f3bc4
--- /dev/null
+++ b/src/engines/config.py
@@ -0,0 +1 @@
+DB_NAME = "test1"
diff --git a/src/engines/import_data_engine.py b/src/engines/import_data_engine.py
new file mode 100644
index 0000000..a6e4a51
--- /dev/null
+++ b/src/engines/import_data_engine.py
@@ -0,0 +1,131 @@
+from typing import List, Optional
+
+import pandas as pd
+
+from data_import import Loader
+from database import DatabaseObjectManager, Writer
+from engines import DB_NAME
+from state import State
+
+
+class ImportDataEngine:
+    def __init__(self, state: State):
+        self.state = state
+        self.reader_data = None
+        self.from_file = False
+        self.database_manager = DatabaseObjectManager()
+        self.loader = Loader()
+
+    def load_data_from_file(self, file_path: str) -> None:
+        """Create a file reader for ``file_path``; re-raise ValueError on failure."""
+        try:
+            self.reader_data = self.loader.create_file_reader(file_path)
+            self.from_file = True
+        except ValueError:
+            self.reader_data = None
+            raise  # keep the original exception and its traceback
+
+    def load_data_from_database(self, document_name: str) -> None:
+        """Create a database reader for ``document_name``; re-raise ValueError."""
+        try:
+            self.reader_data = self.loader.create_database_reader(document_name)
+            self.from_file = False
+        except ValueError:
+            self.reader_data = None
+            raise  # keep the original exception and its traceback
+
+    def get_table_names_from_database(self) -> List[str]:
+        return self.database_manager.get_collections_list(DB_NAME)
+
+    def is_data_big(self) -> bool:
+        return self.from_file and self.reader_data.is_file_big()
+
+    def get_columns(self) -> List[str]:
+        return list(self.state.raw_data.columns)
+
+    def clear_import(self):
+        self.reader_data = None
+        self.state.imported_data = None
+        self.state.raw_data = None
+        self.state.reduced_columns = []
+        self.state.steps_visualization = None
+        self.state.algorithm_results_widgets = {}
+        self.state.last_algorithm = None
+
+    def read_data(self, columns: Optional[List[str]] = None):
+        self.state.imported_data = self.reader_data.read(columns)
+        self.state.raw_data = self.state.imported_data.copy()
+        self.state.reduced_columns = []
+        self.state.steps_visualization = None
+        self.state.algorithm_results_widgets = {}
+        self.state.last_algorithm = None
+
+    def limit_data(
+        self,
+        columns: Optional[List[str]] = None,
+        limit_type: Optional[str] = None,
+        limit_num: Optional[int] = None,
+    ):
+        """Restrict raw data to ``columns`` and to ``limit_num`` rows.
+
+        ``limit_type`` is "first" (leading rows) or "random" (sampled rows);
+        any other value leaves the rows untouched.
+        """
+        self.drop_additional_columns()
+        if columns is not None:
+            self.state.raw_data = self.state.raw_data[columns]
+        if limit_type is not None:
+            if limit_type == "first":
+                self.state.raw_data = self.state.raw_data.iloc[:limit_num]
+            elif limit_type == "random":
+                if limit_num is not None:
+                    # sample() raises when asked for more rows than exist.
+                    limit_num = min(limit_num, len(self.state.raw_data))
+                self.state.raw_data = self.state.raw_data.sample(limit_num).reset_index(
+                    drop=True
+                )
+        self.state.imported_data = self.state.raw_data.copy()
+
+    def drop_additional_columns(self):
+        self.state.raw_data.drop(self.state.reduced_columns, axis=1, inplace=True)
+        self.state.imported_data = self.state.raw_data.copy()
+        self.state.reduced_columns = []
+        self.state.steps_visualization = None
+        self.state.algorithm_results_widgets = {}
+        self.state.last_algorithm = None
+
+    def save_to_database(self, title: str) -> str:
+        """Store the raw data as collection ``title`` and reload it from there.
+
+        Returns an error message, or an empty string on success.
+        """
+        writer = Writer(DB_NAME, title)
+        try:
+            if isinstance(self.state.raw_data, pd.DataFrame):
+                writer.add_dataset(self.state.raw_data)
+            else:
+                # NOTE(review): save_data is not set by State.__init__ — confirm.
+                for chunk in self.state.save_data:
+                    writer.add_dataset(chunk)
+        except Exception as e:
+            print(e)
+            return "There is some problem with database."
+        try:
+            self.load_data_from_database(title)
+        except ValueError as error:
+            # load_data_from_database reports failures by raising, not returning.
+            return str(error)
+        self.read_data()
+        return ""
+
+    def set_generated_data(self, data: pd.DataFrame):
+        self.state.raw_data = data
+        self.state.imported_data = data.copy()
+
+    def merge_sets(self, new_data: pd.DataFrame) -> None:
+        self.state.imported_data = pd.concat(
+            [self.state.imported_data, new_data], ignore_index=True
+        )
+        self.state.raw_data = pd.concat(
+            [self.state.raw_data, new_data], ignore_index=True
+        )
diff --git a/src/engines/preprocessing_engine.py b/src/engines/preprocessing_engine.py
new file mode 100644
index 0000000..5ceccaf
--- /dev/null
+++ b/src/engines/preprocessing_engine.py
@@ -0,0 +1,82 @@
+import numpy as np
+from pandas.api.types import is_numeric_dtype
+
+from preprocess import DataCleaner, PCAReducer
+from state import State
+
+
+class PreprocessingEngine:
+    def __init__(self, state: State):
+        self.state = state
+        self.cleaner = DataCleaner(self.state)
+        self.reducer = PCAReducer(self.state)
+
+    def get_raw_columns(self):
+        if self.state.raw_data is None:
+            return []
+        return self.state.raw_data.columns
+
+    def get_columns(self):
+        if self.state.imported_data is None:
+            return []
+        return self.state.imported_data.columns
+
+    def get_numeric_columns(self):
+        if self.state.imported_data is None:
+            return []
+        return self.state.imported_data.select_dtypes(include=["number"]).columns
+
+    def get_size(self):
+        if self.state.imported_data is None:
+            return 0
+        return len(self.state.imported_data.select_dtypes(include=["number"]))
+
+    def set_state(self, columns):
+        self.state.imported_data = self.state.raw_data[columns].copy()
+
+    def clean_data(self, op_type):
+        match op_type:
+            case "cast":
+                # np.NaN was removed in NumPy 2.0; np.nan works on every version.
+                self.cleaner.cast_nulls(np.nan)
+            case "remove":
+                self.cleaner.remove_nulls()
+
+    def has_rows_with_nulls(self, columns):
+        return self.state.raw_data[columns].isnull().values.any()
+
+    def reduce_dimensions(self, dim_number=None):
+        return self.reducer.reduce(dim_number)
+
+    def number_of_numeric_columns(self):
+        return len(self.get_numeric_columns())
+
+    def rename_column(self, index, new_header):
+        column = self.state.imported_data.columns[index]
+        self.state.imported_data.rename(columns={column: new_header}, inplace=True)
+        self.state.raw_data.rename(columns={column: new_header}, inplace=True)
+        try:
+            # omit on change not reduced column name
+            reduced_arr_idx = self.state.reduced_columns.index(column)
+            self.state.reduced_columns[reduced_arr_idx] = new_header
+        except ValueError:
+            pass
+
+    def mean_or_mode_estimate(self):
+        """Fill missing values with the column mean (numeric) or mode (other).
+
+        The filled column is assigned back to the frame because an in-place
+        ``fillna`` on a selected column may only modify a temporary copy.
+        Columns without any non-missing value are left unchanged.
+        """
+        missing_data_columns = self.get_columns()[
+            self.state.imported_data.isna().any()
+        ].to_list()
+        for header in missing_data_columns:
+            column = self.state.imported_data.loc[:, header]
+            if not column.notna().any():
+                continue
+            new_value = (
+                column.mean() if is_numeric_dtype(column.dtypes) else column.mode()[0]
+            )
+            self.state.imported_data[header] = column.fillna(new_value)
diff --git a/src/engines/results_engine.py b/src/engines/results_engine.py
new file mode 100644
index 0000000..ae1da52
--- /dev/null
+++ b/src/engines/results_engine.py
@@ -0,0 +1,6 @@
+from state import State
+
+
+class ResultsEngine:
+    def __init__(self, state: State):
+        self.state = state
diff --git a/src/preprocess/__init__.py b/src/preprocess/__init__.py
index e69de29..4a56fce 100644
--- a/src/preprocess/__init__.py
+++ b/src/preprocess/__init__.py
@@ -0,0 +1,2 @@
+from .cleaning import DataCleaner
+from .reduction import PCAReducer
diff --git a/src/preprocess/cleaning/__init__.py b/src/preprocess/cleaning/__init__.py
index e69de29..fdfb78f 100644
--- a/src/preprocess/cleaning/__init__.py
+++ b/src/preprocess/cleaning/__init__.py
@@ -0,0 +1 @@
+from .clean_data import DataCleaner
diff --git a/src/preprocess/cleaning/clean_data.py b/src/preprocess/cleaning/clean_data.py
new file mode 100644
index 0000000..d17e701
--- /dev/null
+++ b/src/preprocess/cleaning/clean_data.py
@@ -0,0 +1,12 @@
+from state import State
+
+
+class DataCleaner:
+    def __init__(self, state: State):
+        self.state = state
+
+    def cast_nulls(self, value):
+        self.state.imported_data = self.state.imported_data.fillna(value)
+
+    def remove_nulls(self):
+        self.state.imported_data.dropna(inplace=True)
diff --git a/src/preprocess/reduction/__init__.py b/src/preprocess/reduction/__init__.py
index e69de29..348c94e 100644
--- a/src/preprocess/reduction/__init__.py
+++ b/src/preprocess/reduction/__init__.py
@@ -0,0 +1 @@
+from .pca import PCAReducer
diff --git a/src/preprocess/reduction/pca.py b/src/preprocess/reduction/pca.py
new file mode 100644
index 0000000..cae5ee9
--- /dev/null
+++ b/src/preprocess/reduction/pca.py
@@ -0,0 +1,99 @@
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+
+from state import State
+
+
+class PCAReducer:
+    def __init__(self, state: State, acceptable_ratio=0.05):
+        self.state = state
+        self.acceptable_ratio = acceptable_ratio
+        self.initial_columns = []
+
+    def reduce(self, dim_number: Optional[int]) -> list[str]:
+        """Project the numeric columns of ``imported_data`` onto principal axes.
+
+        ``imported_data`` is replaced by the projected frame and the same
+        columns are appended to ``raw_data``. Returns the new column names.
+        """
+        data = self.state.imported_data.select_dtypes(include=np.number)
+        data = data - data.mean()
+        self.initial_columns = list(data.columns)
+        covariance_matrix = data.cov()
+        reduce_matrix = self._pca(covariance_matrix, dim_number)
+        override = np.dot(data, reduce_matrix)
+        columns = [
+            "{}".format(self.format_column_name(reduce_matrix[:, i]))
+            for i in range(dim_number or override.shape[1])
+        ]
+        self.state.imported_data = pd.DataFrame(override, columns=columns)
+        # Reuse the source index so concat pairs each projected row with the
+        # raw row it came from, even when rows were dropped beforehand.
+        aligned = pd.DataFrame(override, columns=columns, index=data.index)
+        self.state.raw_data = pd.concat([self.state.raw_data, aligned], axis=1)
+        return columns
+
+    def _pca(self, matrix, dim_number=None):
+        reducer, weights = self._svd(matrix, k=dim_number)
+        if dim_number is None:
+            total = sum(weights)
+            ratios = [weight / total for weight in weights]
+            columns_num = max(
+                len(list(filter(lambda x: x > self.acceptable_ratio, ratios))), 2
+            )
+            return reducer[:, :columns_num]
+        return reducer
+
+    @staticmethod
+    def _dominant_component(matrix, eps, max_iter=100_000):
+        """Approximate the dominant eigenvector of ``matrix`` by power iteration.
+
+        Iteration stops on convergence within ``eps``, after ``max_iter``
+        rounds, or when the product is zero or non-finite, so a degenerate
+        matrix cannot loop forever.
+        """
+        m, n = matrix.shape
+        v = np.ones(n) / np.sqrt(n)
+        for _ in range(max_iter):
+            new_v = np.dot(matrix, v)
+            norm = np.linalg.norm(new_v)
+            if norm == 0 or not np.isfinite(norm):
+                break
+            new_v = new_v / norm
+            if abs(np.linalg.norm(v - new_v)) < eps:
+                break
+            v = new_v
+        return v
+
+    def _svd(self, matrix, k=None, eps=1e-10):
+        matrix_helper = np.array(matrix)
+        m, n = matrix_helper.shape
+        svd_components = []
+        if k is None:
+            k = n
+        for i in range(k):
+            v = self._dominant_component(matrix_helper, eps=eps)
+            u_unnormalized = np.dot(matrix, v)
+            sigma = np.linalg.norm(u_unnormalized)
+            # A zero singular value has no direction; use a zero vector, not NaN.
+            u = u_unnormalized / sigma if sigma > 0 else np.zeros_like(u_unnormalized)
+
+            matrix_helper -= sigma * np.outer(u, v)
+
+            svd_components.append((u, sigma))
+
+        U, Sigma = [np.array(x) for x in zip(*svd_components)]
+        return U.T, Sigma
+
+    def format_column_name(self, vector):
+        """Label a component by its (up to) two largest coefficients."""
+        label = ""
+        top = min(2, len(vector))  # argpartition(-2) fails on 1-element vectors
+        indexes = np.argpartition(vector, -top)[-top:]
+        for index in indexes:  # arbitrary value
+            label += "{}*{}+".format(
+                round(vector[index], 2), self.initial_columns[index]
+            )
+        return label.rstrip("+")
diff --git a/src/state.py b/src/state.py
new file mode 100644
index 0000000..d90278a
--- /dev/null
+++ b/src/state.py
@@ -0,0 +1,13 @@
+from PyQt5.QtWidgets import QWidget
+
+
+class State:
+    algorithm_results_widgets: dict[str, dict[str, list[QWidget]]]
+
+    def __init__(self):
+        self.raw_data = None
+        self.imported_data = None
+        self.steps_visualization = None
+        self.algorithm_results_widgets = {}
+        self.reduced_columns = []
+        self.last_algorithm = None
diff --git a/src/utils/QImage.py b/src/utils/QImage.py
new file mode 100644
index 0000000..8d551d1
--- /dev/null
+++ b/src/utils/QImage.py
@@ -0,0 +1,26 @@
+from PyQt5.QtCore import QRect
+from PyQt5.QtGui import QPainter, QPaintEvent, QPixmap
+from PyQt5.QtWidgets import QWidget
+
+
+class QImage(QWidget):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.p = QPixmap()
+
+    def setPixmap(self, p: QPixmap):
+        self.p = p
+        self.update()
+
+    def paintEvent(self, event: QPaintEvent) -> None:
+        if not self.p.isNull():
+            painter = QPainter(self)
+            painter.setRenderHint(QPainter.SmoothPixmapTransform)
+            _, _, w_widget, h_widget = self.rect().getRect()
+            x, y, w, h = self.p.rect().getRect()
+            if w > w_widget or h > h_widget:
+                alfa = min(w_widget / w, h_widget / h)
+                w = int(alfa * w)
+                h = int(alfa * h)
+            x = int(0.5 * (w_widget - w))
+            painter.drawPixmap(QRect(x, y, w, h), self.p)
diff --git a/src/utils/QtImageViewer.py b/src/utils/QtImageViewer.py
new file mode 100644
index 0000000..c4b4faa
--- /dev/null
+++ b/src/utils/QtImageViewer.py
@@ -0,0 +1,686 @@
+""" QtImageViewer.py: PyQt image viewer widget based on QGraphicsView with mouse zooming/panning and ROIs.
+"""
+
+import os.path
+
+try:
+    from PyQt6.QtCore import QEvent, QPoint, QPointF, QRectF, QSize, Qt, pyqtSignal
+    from PyQt6.QtGui import QImage, QMouseEvent, QPainter, QPainterPath, QPen, QPixmap
+    from PyQt6.QtWidgets import (
+        QFileDialog,
+        QGraphicsEllipseItem,
+        QGraphicsItem,
+        QGraphicsLineItem,
+        QGraphicsPolygonItem,
+        QGraphicsRectItem,
+        QGraphicsScene,
+        QGraphicsView,
+        QSizePolicy,
+    )
+except ImportError:
+    try:
+        from PyQt5.QtCore import QEvent, QPoint, QPointF, QRectF, QSize, Qt, pyqtSignal
+        from PyQt5.QtGui import (
+            QImage,
+            QMouseEvent,
+            QPainter,
+            QPainterPath,
+            QPen,
+            QPixmap,
+        )
+        from PyQt5.QtWidgets import (
+            QFileDialog,
+            QGraphicsEllipseItem,
+            QGraphicsItem,
+            QGraphicsLineItem,
+            QGraphicsPolygonItem,
+            QGraphicsRectItem,
+            QGraphicsScene,
+            QGraphicsView,
+            QSizePolicy,
+        )
+    except ImportError:
+        raise ImportError("Requires PyQt (version 5 or 6)")
+
+# numpy is optional: only needed if you want to display numpy 2d arrays as images.
+try:
+    import numpy as np
+except ImportError:
+    np = None
+
+# qimage2ndarray is optional: useful for displaying numpy 2d arrays as images.
+# !!! qimage2ndarray requires PyQt5.
+# Some custom code in the viewer appears to handle the conversion from numpy 2d arrays,
+# so qimage2ndarray probably is not needed anymore. I've left it here just in case.
+try:
+try: + import qimage2ndarray +except ImportError: + qimage2ndarray = None + +__author__ = "Marcel Goldschen-Ohm " +__version__ = "2.0.0" + + +class QtImageViewer(QGraphicsView): + """PyQt image viewer widget based on QGraphicsView with mouse zooming/panning and ROIs. + Image File: + ----------- + Use the open("path/to/file") method to load an image file into the viewer. + Calling open() without a file argument will popup a file selection dialog. + Image: + ------ + Use the setImage(im) method to set the image data in the viewer. + - im can be a QImage, QPixmap, or NumPy 2D array (the later requires the package qimage2ndarray). + For display in the QGraphicsView the image will be converted to a QPixmap. + Some useful image format conversion utilities: + qimage2ndarray: NumPy ndarray <==> QImage (https://github.com/hmeine/qimage2ndarray) + ImageQt: PIL Image <==> QImage (https://github.com/python-pillow/Pillow/blob/master/PIL/ImageQt.py) + Mouse: + ------ + Mouse interactions for zooming and panning is fully customizable by simply setting the desired button interactions: + e.g., + regionZoomButton = Qt.LeftButton # Drag a zoom box. + zoomOutButton = Qt.RightButton # Pop end of zoom stack (double click clears zoom stack). + panButton = Qt.MiddleButton # Drag to pan. + wheelZoomFactor = 1.25 # Set to None or 1 to disable mouse wheel zoom. + To disable any interaction, just disable its button. + e.g., to disable panning: + panButton = None + ROIs: + ----- + Can also add ellipse, rectangle, line, and polygon ROIs to the image. + ROIs should be derived from the provided EllipseROI, RectROI, LineROI, and PolygonROI classes. + ROIs are selectable and optionally moveable with the mouse (see setROIsAreMovable). + TODO: Add support for editing the displayed image contrast. + TODO: Add support for drawing ROIs with the mouse. + """ + + # Mouse button signals emit image scene (x, y) coordinates. + # !!! For image (row, column) matrix indexing, row = y and column = x. + # !!! 
These signals will NOT be emitted if the event is handled by an interaction such as zoom or pan. + # !!! If aspect ratio prevents image from filling viewport, emitted position may be outside image bounds. + leftMouseButtonPressed = pyqtSignal(float, float) + leftMouseButtonReleased = pyqtSignal(float, float) + middleMouseButtonPressed = pyqtSignal(float, float) + middleMouseButtonReleased = pyqtSignal(float, float) + rightMouseButtonPressed = pyqtSignal(float, float) + rightMouseButtonReleased = pyqtSignal(float, float) + leftMouseButtonDoubleClicked = pyqtSignal(float, float) + rightMouseButtonDoubleClicked = pyqtSignal(float, float) + + # Emitted upon zooming/panning. + viewChanged = pyqtSignal() + + # Emitted on mouse motion. + # Emits mouse position over image in image pixel coordinates. + # !!! setMouseTracking(True) if you want to use this at all times. + mousePositionOnImageChanged = pyqtSignal(QPoint) + + # Emit index of selected ROI + roiSelected = pyqtSignal(int) + + def __init__(self): + QGraphicsView.__init__(self) + + # Image is displayed as a QPixmap in a QGraphicsScene attached to this QGraphicsView. + self.scene = QGraphicsScene() + self.setScene(self.scene) + + # Better quality pixmap scaling? + # self.setRenderHints(QPainter.Antialiasing | QPainter.SmoothPixmapTransform) + + # Displayed image pixmap in the QGraphicsScene. + self._image = None + + # Image aspect ratio mode. + # Qt.IgnoreAspectRatio: Scale image to fit viewport. + # Qt.KeepAspectRatio: Scale image to fit inside viewport, preserving aspect ratio. + # Qt.KeepAspectRatioByExpanding: Scale image to fill the viewport, preserving aspect ratio. + self.aspectRatioMode = Qt.AspectRatioMode.KeepAspectRatio + + # Scroll bar behaviour. + # Qt.ScrollBarAlwaysOff: Never shows a scroll bar. + # Qt.ScrollBarAlwaysOn: Always shows a scroll bar. + # Qt.ScrollBarAsNeeded: Shows a scroll bar only when zoomed. 
+ self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + + # Interactions (set buttons to None to disable interactions) + # !!! Events handled by interactions will NOT emit *MouseButton* signals. + # Note: regionZoomButton will still emit a *MouseButtonReleased signal on a click (i.e. tiny box). + self.regionZoomButton = Qt.MouseButton.LeftButton # Drag a zoom box. + self.zoomOutButton = ( + Qt.MouseButton.RightButton + ) # Pop end of zoom stack (double click clears zoom stack). + self.panButton = Qt.MouseButton.MiddleButton # Drag to pan. + self.wheelZoomFactor = 1.25 # Set to None or 1 to disable mouse wheel zoom. + + # Stack of QRectF zoom boxes in scene coordinates. + # !!! If you update this manually, be sure to call updateViewer() to reflect any changes. + self.zoomStack = [] + + # Flags for active zooming/panning. + self._isZooming = False + self._isPanning = False + + # Store temporary position in screen pixels or scene units. + self._pixelPosition = QPoint() + self._scenePosition = QPointF() + + # Track mouse position. e.g., For displaying coordinates in a UI. + # self.setMouseTracking(True) + + # ROIs. + self.ROIs = [] + + # # For drawing ROIs. + # self.drawROI = None + + self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + def sizeHint(self): + return QSize(900, 600) + + def hasImage(self): + """Returns whether the scene contains an image pixmap.""" + return self._image is not None + + def clearImage(self): + """Removes the current image pixmap from the scene if it exists.""" + if self.hasImage(): + self.scene.removeItem(self._image) + self._image = None + + def pixmap(self): + """Returns the scene's current image pixmap as a QPixmap, or else None if no image exists. 
+ :rtype: QPixmap | None + """ + if self.hasImage(): + return self._image.pixmap() + return None + + def image(self): + """Returns the scene's current image pixmap as a QImage, or else None if no image exists. + :rtype: QImage | None + """ + if self.hasImage(): + return self._image.pixmap().toImage() + return None + + def setImage(self, image): + """Set the scene's current image pixmap to the input QImage or QPixmap. + Raises a RuntimeError if the input image has type other than QImage or QPixmap. + :type image: QImage | QPixmap + """ + if type(image) is QPixmap: + pixmap = image + elif type(image) is QImage: + pixmap = QPixmap.fromImage(image) + elif (np is not None) and (type(image) is np.ndarray): + if qimage2ndarray is not None: + qimage = qimage2ndarray.array2qimage(image, True) + pixmap = QPixmap.fromImage(qimage) + else: + image = image.astype(np.float32) + image -= image.min() + image /= image.max() + image *= 255 + image[image > 255] = 255 + image[image < 0] = 0 + image = image.astype(np.uint8) + height, width = image.shape + bytes = image.tobytes() + qimage = QImage(bytes, width, height, QImage.Format.Format_Grayscale8) + pixmap = QPixmap.fromImage(qimage) + else: + raise RuntimeError( + "ImageViewer.setImage: Argument must be a QImage, QPixmap, or numpy.ndarray." + ) + if self.hasImage(): + self._image.setPixmap(pixmap) + else: + self._image = self.scene.addPixmap(pixmap) + + # Better quality pixmap scaling? + # !!! This will distort actual pixel data when zoomed way in. + # For scientific image analysis, you probably don't want this. + # self._pixmap.setTransformationMode(Qt.SmoothTransformation) + + self.setSceneRect(QRectF(pixmap.rect())) # Set scene size to image size. + self.updateViewer() + + def open(self, filepath=None): + """Load an image from file. + Without any arguments, loadImageFromFile() will pop up a file dialog to choose the image file. 
def clearZoom(self):
    """Discard every stored zoom box, refresh the viewer and emit viewChanged.

    Does nothing (and emits nothing) when the zoom stack is already empty.
    """
    if not len(self.zoomStack):
        return
    self.zoomStack = []
    self.updateViewer()
    self.viewChanged.emit()
+ if (self.regionZoomButton is not None) and ( + event.button() == self.regionZoomButton + ): + self._pixelPosition = event.pos() # store pixel position + self.setDragMode(QGraphicsView.DragMode.RubberBandDrag) + QGraphicsView.mousePressEvent(self, event) + event.accept() + self._isZooming = True + return + + if (self.zoomOutButton is not None) and (event.button() == self.zoomOutButton): + if len(self.zoomStack): + self.zoomStack.pop() + self.updateViewer() + self.viewChanged.emit() + event.accept() + return + + # Start dragging to pan? + if (self.panButton is not None) and (event.button() == self.panButton): + self._pixelPosition = event.pos() # store pixel position + self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag) + if self.panButton == Qt.MouseButton.LeftButton: + QGraphicsView.mousePressEvent(self, event) + else: + # ScrollHandDrag ONLY works with LeftButton, so fake it. + # Use a bunch of dummy modifiers to notify that event should NOT be handled as usual. + self.viewport().setCursor(Qt.CursorShape.ClosedHandCursor) + dummyModifiers = Qt.KeyboardModifier( + Qt.KeyboardModifier.ShiftModifier + | Qt.KeyboardModifier.ControlModifier + | Qt.KeyboardModifier.AltModifier + | Qt.KeyboardModifier.MetaModifier + ) + dummyEvent = QMouseEvent( + QEvent.Type.MouseButtonPress, + QPointF(event.pos()), + Qt.MouseButton.LeftButton, + event.buttons(), + dummyModifiers, + ) + self.mousePressEvent(dummyEvent) + sceneViewport = ( + self.mapToScene(self.viewport().rect()) + .boundingRect() + .intersected(self.sceneRect()) + ) + self._scenePosition = sceneViewport.topLeft() + event.accept() + self._isPanning = True + return + + scenePos = self.mapToScene(event.pos()) + if event.button() == Qt.MouseButton.LeftButton: + self.leftMouseButtonPressed.emit(scenePos.x(), scenePos.y()) + elif event.button() == Qt.MouseButton.MiddleButton: + self.middleMouseButtonPressed.emit(scenePos.x(), scenePos.y()) + elif event.button() == Qt.MouseButton.RightButton: + 
self.rightMouseButtonPressed.emit(scenePos.x(), scenePos.y()) + + QGraphicsView.mousePressEvent(self, event) + + def mouseReleaseEvent(self, event): + """Stop mouse pan or zoom mode (apply zoom if valid).""" + # Ignore dummy events. e.g., Faking pan with left button ScrollHandDrag. + dummyModifiers = Qt.KeyboardModifier( + Qt.KeyboardModifier.ShiftModifier + | Qt.KeyboardModifier.ControlModifier + | Qt.KeyboardModifier.AltModifier + | Qt.KeyboardModifier.MetaModifier + ) + if event.modifiers() == dummyModifiers: + QGraphicsView.mouseReleaseEvent(self, event) + event.accept() + return + + # Finish dragging a region zoom box? + if (self.regionZoomButton is not None) and ( + event.button() == self.regionZoomButton + ): + QGraphicsView.mouseReleaseEvent(self, event) + zoomRect = ( + self.scene.selectionArea().boundingRect().intersected(self.sceneRect()) + ) + # Clear current selection area (i.e. rubberband rect). + self.scene.setSelectionArea(QPainterPath()) + self.setDragMode(QGraphicsView.DragMode.NoDrag) + # If zoom box is 3x3 screen pixels or smaller, do not zoom and proceed to process as a click release. + zoomPixelWidth = abs(event.pos().x() - self._pixelPosition.x()) + zoomPixelHeight = abs(event.pos().y() - self._pixelPosition.y()) + if zoomPixelWidth > 3 and zoomPixelHeight > 3: + if zoomRect.isValid() and (zoomRect != self.sceneRect()): + self.zoomStack.append(zoomRect) + self.updateViewer() + self.viewChanged.emit() + event.accept() + self._isZooming = False + return + + # Finish panning? + if (self.panButton is not None) and (event.button() == self.panButton): + if self.panButton == Qt.MouseButton.LeftButton: + QGraphicsView.mouseReleaseEvent(self, event) + else: + # ScrollHandDrag ONLY works with LeftButton, so fake it. + # Use a bunch of dummy modifiers to notify that event should NOT be handled as usual. 
+ self.viewport().setCursor(Qt.CursorShape.ArrowCursor) + dummyModifiers = Qt.KeyboardModifier( + Qt.KeyboardModifier.ShiftModifier + | Qt.KeyboardModifier.ControlModifier + | Qt.KeyboardModifier.AltModifier + | Qt.KeyboardModifier.MetaModifier + ) + dummyEvent = QMouseEvent( + QEvent.Type.MouseButtonRelease, + QPointF(event.pos()), + Qt.MouseButton.LeftButton, + event.buttons(), + dummyModifiers, + ) + self.mouseReleaseEvent(dummyEvent) + self.setDragMode(QGraphicsView.DragMode.NoDrag) + if len(self.zoomStack) > 0: + sceneViewport = ( + self.mapToScene(self.viewport().rect()) + .boundingRect() + .intersected(self.sceneRect()) + ) + delta = sceneViewport.topLeft() - self._scenePosition + self.zoomStack[-1].translate(delta) + self.zoomStack[-1] = self.zoomStack[-1].intersected(self.sceneRect()) + self.viewChanged.emit() + event.accept() + self._isPanning = False + return + + scenePos = self.mapToScene(event.pos()) + if event.button() == Qt.MouseButton.LeftButton: + self.leftMouseButtonReleased.emit(scenePos.x(), scenePos.y()) + elif event.button() == Qt.MouseButton.MiddleButton: + self.middleMouseButtonReleased.emit(scenePos.x(), scenePos.y()) + elif event.button() == Qt.MouseButton.RightButton: + self.rightMouseButtonReleased.emit(scenePos.x(), scenePos.y()) + + QGraphicsView.mouseReleaseEvent(self, event) + + def mouseDoubleClickEvent(self, event): + """Show entire image.""" + # Zoom out on double click? 
def wheelEvent(self, event):
    """Zoom in or out around the center of the current zoom box with the mouse wheel.

    A negative vertical wheel delta shrinks the zoom box (zoom in); a positive
    one grows it (zoom out). The zoom stack is collapsed to a single box, and it
    is cleared once zooming out reaches the whole scene rect.

    With ``wheelZoomFactor`` set to None the event is forwarded to
    QGraphicsView; with a factor of 1 it is ignored. Events without a vertical
    component (e.g. horizontal scrolling) never change the zoom.
    """
    if self.wheelZoomFactor is None:
        # Wheel zoom disabled: fall back to the default view behaviour.
        QGraphicsView.wheelEvent(self, event)
        return
    if self.wheelZoomFactor == 1:
        return

    verticalDelta = event.angleDelta().y()
    if verticalDelta == 0:
        # No vertical motion means no zoom direction; previously this case
        # fell into the zoom-out branch.
        event.accept()
        return

    zoomingIn = verticalDelta < 0
    if len(self.zoomStack) == 0:
        if not zoomingIn:
            # Already fully zoomed out.
            return
        self.zoomStack.append(self.sceneRect())
    elif len(self.zoomStack) > 1:
        # Wheel zoom works on a single box: keep only the current one.
        del self.zoomStack[:-1]

    zoomRect = self.zoomStack[-1]
    center = zoomRect.center()
    if zoomingIn:
        zoomRect.setWidth(zoomRect.width() / self.wheelZoomFactor)
        zoomRect.setHeight(zoomRect.height() / self.wheelZoomFactor)
    else:
        zoomRect.setWidth(zoomRect.width() * self.wheelZoomFactor)
        zoomRect.setHeight(zoomRect.height() * self.wheelZoomFactor)
    zoomRect.moveCenter(center)
    # Never let the zoom box extend beyond the scene.
    self.zoomStack[-1] = zoomRect.intersected(self.sceneRect())
    if not zoomingIn and self.zoomStack[-1] == self.sceneRect():
        # Zoomed out to the full scene: an empty stack means "not zoomed".
        self.zoomStack = []
    self.updateViewer()
    self.viewChanged.emit()
    event.accept()
class EllipseROI(QGraphicsEllipseItem):
    """Selectable elliptical ROI drawn with a cosmetic yellow outline.

    Left-clicking the item reports it to the owning viewer through
    ``viewer.roiClicked(self)``.
    """

    def __init__(self, viewer):
        QGraphicsItem.__init__(self)
        # Viewer to notify when this ROI is clicked.
        self._viewer = viewer
        # Use the fully scoped enum like the other ROI classes: the unscoped
        # ``Qt.yellow`` shortcut is not available under PyQt6.
        pen = QPen(Qt.GlobalColor.yellow)
        # Cosmetic pen: outline width stays constant regardless of view zoom.
        pen.setCosmetic(True)
        self.setPen(pen)
        self.setFlags(self.GraphicsItemFlag.ItemIsSelectable)

    def mousePressEvent(self, event):
        """Run the default press handling, then report left clicks to the viewer."""
        QGraphicsItem.mousePressEvent(self, event)
        if event.button() == Qt.MouseButton.LeftButton:
            self._viewer.roiClicked(self)
class AutomateSteps:
    """Run a callback repeatedly from a background ``Runner`` thread.

    ``to_execute`` is connected to the runner's signal, which is emitted every
    ``step_time`` milliseconds while the runner is active. ``when_restart`` is
    called by :meth:`restart` after the runner has been stopped.
    """

    def __init__(self, to_execute: Callable, when_restart: Callable):
        self.to_execute = to_execute
        self.when_restart = when_restart
        # Active Runner, or None while paused.
        self.thread = None
        # Delay between steps in milliseconds; applied on the next resume().
        self.step_time = 10

    def set_time(self, time: int):
        """Set the delay between steps (milliseconds) used by the next resume()."""
        self.step_time = time

    def resume(self):
        """Start stepping; a runner that is still active is stopped first."""
        # Without this, a second resume() would orphan the previous runner,
        # which would keep emitting steps with nothing left to stop it.
        self.pause()
        self.thread = Runner()
        self.thread.signal.connect(self.to_execute)
        self.thread.is_running = True
        self.thread.step_time = self.step_time
        self.thread.start()

    def pause(self):
        """Stop the active runner, if any, and wait until its thread has ended."""
        if self.thread is None:
            return
        runner = self.thread
        runner.is_running = False
        # Wait for run() to return before dropping the last reference: a
        # QThread must not be destroyed while it is still running. The wait
        # lasts at most about one step_time (the runner's sleep interval).
        runner.wait()
        self.thread = None

    def restart(self):
        """Stop stepping and invoke the ``when_restart`` callback."""
        self.pause()
        self.when_restart()
class APrioriScatterPlot(FigureCanvasQTAgg):
    """Scatter plot with one randomly placed point per transaction.

    ``plot_set`` highlights the transactions containing a frequent set and shows
    their items in a tooltip on hover; ``plot_rule`` colours transactions by
    which side(s) of an association rule they contain.
    """

    HIGHLIGHTED_COLOR = "#f17300"
    REGULAR_COLOR = "#F5F5F5"
    HIGHLIGHTED_RULE_A_COLOR = "#81a4cd"
    HIGHLIGHTED_RULE_B_COLOR = "#054a91"
    HIGHLIGHTED_RULE_A_B_COLOR = "#f17300"

    def __init__(self, transaction_sets):
        # Tooltip annotation and hoverable scatter; both only exist after plot_set().
        self.annot = None
        self.sc = None
        # For every point of self.sc, the index of its transaction in
        # transaction_sets (self.sc only holds the highlighted subset).
        self._highlighted_indices = []
        fig, self.axes = plt.subplots()
        fig.tight_layout()
        self.transaction_sets = transaction_sets
        self._assign_random_positions()

        super().__init__(fig)
        # Connect the hover handler exactly once; connecting in plot_set()
        # would stack one more callback on every redraw.
        self.mpl_connect("motion_notify_event", self.hover)

    def plot_set(self, highlighted_set: tuple = None):
        """Draw all transactions, highlighting those that contain ``highlighted_set``."""
        self.axes.cla()
        self.axes.axis("off")

        # cla() removed the previous annotation, so create a fresh one.
        self.annot = self.axes.annotate(
            "",
            xy=(0, 0),
            xytext=(-20, 0),
            textcoords="offset points",
            bbox=dict(boxstyle="round", fc="w"),
        )
        # Hidden until the mouse is over a highlighted point.
        self.annot.set_visible(False)

        mask = self._get_points_mask_set(
            set(highlighted_set) if highlighted_set is not None else None
        )
        self._highlighted_indices = [i for i, hit in enumerate(mask) if hit]
        self.sc = self.axes.scatter(
            list(compress(self.x_positions, mask)),
            list(compress(self.y_positions, mask)),
            color=self.HIGHLIGHTED_COLOR,
            s=10,
            zorder=2,
            label="containing the frequent set",
        )

        inverse_mask = [not hit for hit in mask]
        self.axes.scatter(
            list(compress(self.x_positions, inverse_mask)),
            list(compress(self.y_positions, inverse_mask)),
            color=self.REGULAR_COLOR,
            s=10,
            zorder=1,
            label="not containing",
        )

        self.axes.legend()
        self.draw()

    def plot_rule(self, set_a: tuple, set_b: tuple):
        """Draw all transactions coloured by containment of ``set_a`` and/or ``set_b``."""
        self.axes.cla()
        self.axes.axis("off")
        # The artists used for hovering were removed by cla(); disable the tooltip.
        self._clear_hover_state()

        mask_a, mask_b, mask_a_b, mask_remaining = self._get_points_mask_rule(
            set(set_a), set(set_b)
        )
        groups = (
            (mask_a, self.HIGHLIGHTED_RULE_A_COLOR, 3, "A"),
            (mask_b, self.HIGHLIGHTED_RULE_B_COLOR, 2, "B"),
            (mask_a_b, self.HIGHLIGHTED_RULE_A_B_COLOR, 4, "A and B"),
            (mask_remaining, self.REGULAR_COLOR, 1, "~(A or B)"),
        )
        for mask, color, zorder, label in groups:
            self.axes.scatter(
                list(compress(self.x_positions, mask)),
                list(compress(self.y_positions, mask)),
                color=color,
                s=10,
                zorder=zorder,
                label=label,
            )

        self.axes.legend()
        self.draw()

    def _assign_random_positions(self):
        """Give every transaction a random (x, y) position in the unit square."""
        self.x_positions = [random() for _ in self.transaction_sets]
        self.y_positions = [random() for _ in self.transaction_sets]

    def _clear_hover_state(self):
        """Forget the hoverable scatter and tooltip so hover() becomes a no-op."""
        self.sc = None
        self.annot = None
        self._highlighted_indices = []

    def _get_points_mask_set(self, highlighted_set: set = None):
        """Return one bool per transaction: does it contain ``highlighted_set``?"""
        if highlighted_set is None:
            return [False] * len(self.transaction_sets)

        wanted = set(highlighted_set)
        return [
            wanted.issubset(transaction_set)
            for transaction_set in self.transaction_sets
        ]

    def _get_points_mask_rule(self, set_a: set, set_b: set):
        """Return four masks: only A, only B, both A and B, neither."""
        flags = [
            (set_a.issubset(transaction_set), set_b.issubset(transaction_set))
            for transaction_set in self.transaction_sets
        ]
        return (
            [has_a and not has_b for has_a, has_b in flags],
            [has_b and not has_a for has_a, has_b in flags],
            [has_a and has_b for has_a, has_b in flags],
            [not has_a and not has_b for has_a, has_b in flags],
        )

    def update_annot(self, ind):
        """Move the tooltip to the hovered point and list the hovered transactions."""
        hits = ind["ind"]
        self.annot.xy = self.sc.get_offsets()[hits[0]]
        # ``hits`` index the highlighted scatter only, so translate them back
        # to positions in transaction_sets before looking up the items.
        text = "\n---\n".join(
            "\n".join(map(str, self.transaction_sets[self._highlighted_indices[n]]))
            for n in hits
        )
        self.annot.set_text(text)
        self.annot.get_bbox_patch().set_alpha(0.4)

    def hover(self, event):
        """Show the tooltip while the mouse is over a highlighted point."""
        if self.sc is None or self.annot is None:
            # Nothing hoverable is drawn (before plot_set, or after plot_rule/reset).
            return
        vis = self.annot.get_visible()
        if event.inaxes == self.axes:
            cont, ind = self.sc.contains(event)
            if cont:
                self.update_annot(ind)
                self.annot.set_visible(True)
                self.figure.canvas.draw_idle()
            elif vis:
                self.annot.set_visible(False)
                self.figure.canvas.draw_idle()

    def reset(self):
        """Clear the axes and disable the hover tooltip."""
        self.axes.cla()
        self.axes.axis("off")
        self._clear_hover_state()
def plot_graph(self, graph: nx.Graph, pos, labels):
    """Render ``graph`` as a plotly figure inside the embedded browser.

    ``pos`` maps each node to its (x, y) coordinates and ``labels`` supplies the
    text shown next to the node markers.
    """
    line_color = "#054a91"

    def hidden_axis():
        # Fresh dict per axis: no grid, zero line or tick labels, fixed range.
        return dict(
            showgrid=False, zeroline=False, showticklabels=False, range=[-2, 2]
        )

    # Each edge contributes its two endpoints followed by None, which
    # separates consecutive segments within the single line trace.
    edge_x, edge_y = [], []
    for edge in graph.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    node_x, node_y = [], []
    for node in graph.nodes():
        node_x.append(pos[node][0])
        node_y.append(pos[node][1])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=4, color=line_color),
        mode="lines",
        hoverinfo="skip",
    )
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        textposition="top center",
        hoverinfo="text",
        text=labels,
        marker=dict(color=line_color, size=20),
    )

    layout = go.Layout(
        showlegend=False,
        paper_bgcolor="#dbe4ee",
        margin=dict(t=0, b=0, l=0, r=0),
        xaxis=hidden_axis(),
        yaxis=hidden_axis(),
    )
    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)

    html = fig.to_html(include_plotlyjs="cdn")
    self.browser.setHtml(html)
def data_plot(
    self,
    vector_x,
    vector_y,
    name_x,
    name_y,
    min_x,
    max_x,
    min_y,
    max_y,
    drawing=True,
):
    """Clear the axes and scatter the raw points with labelled, bounded axes.

    The scatter is kept in ``self.sc``. The canvas is redrawn when ``drawing``
    is true. Returns the axes' collections when the canvas was created in
    animation mode, otherwise None.
    """
    ax = self.axes
    ax.cla()
    ax.set_xlabel(name_x)
    ax.set_ylabel(name_y)
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)
    self.sc = ax.scatter(x=vector_x, y=vector_y)
    if drawing:
        self.draw()
    return ax.collections if self.animation else None
else: + max_label = max(label) + 1 + self.axes.set_xlabel(name_x) + self.axes.set_ylabel(name_y) + self.axes.set_xlim(min_x, max_x) + self.axes.set_ylim(min_y, max_y) + self.sc = self.axes.scatter( + vector_x, vector_y, c=label, cmap="gist_rainbow", vmin=0, vmax=max_label + ) + if vector_x_centroids is not None: + self.axes.scatter( + vector_x_centroids, + vector_y_centroids, + c=np.arange(max_label), + marker="s", + cmap="gist_rainbow", + vmin=0, + vmax=max_label, + edgecolor="black", + linewidths=1, + ) + if drawing: + self.draw() + if self.animation: + return self.axes.collections + + def new_centroids_plot( + self, + old_vector_x_centroids, + old_vector_y_centroids, + vector_x_centroids, + vector_y_centroids, + name_x, + name_y, + min_x, + max_x, + min_y, + max_y, + drawing=True, + ): + self.axes.cla() + max_label = len(vector_x_centroids) + self.axes.set_xlabel(name_x) + self.axes.set_ylabel(name_y) + self.axes.set_xlim(min_x, max_x) + self.axes.set_ylim(min_y, max_y) + if old_vector_x_centroids is not None: + self.axes.scatter( + old_vector_x_centroids, + old_vector_y_centroids, + c=np.arange(max_label), + marker="s", + cmap="gist_rainbow", + vmin=0, + vmax=max_label, + alpha=0.3, + ) + self.axes.scatter( + vector_x_centroids, + vector_y_centroids, + c=np.arange(max_label), + marker="s", + cmap="gist_rainbow", + vmin=0, + vmax=max_label, + edgecolor="black", + linewidths=1, + ) + if drawing: + self.draw() + if self.animation: + return self.axes.collections + + def chosen_centroid_plot( + self, + vector_x, + vector_y, + other_x, + other_y, + old_x_centroid, + old_y_centroid, + x_centroid, + y_centroid, + label, + max_label, + name_x, + name_y, + min_x, + max_x, + min_y, + max_y, + drawing=True, + ): + self.axes.cla() + self.axes.set_xlabel(name_x) + self.axes.set_ylabel(name_y) + self.axes.set_xlim(min_x, max_x) + self.axes.set_ylim(min_y, max_y) + self.axes.scatter(other_x, other_y, c="white", edgecolor="grey") + self.axes.scatter( + vector_x, + vector_y, + 
def _draw_variance(self, mean, sigma, label, max_label, n_std=2.0):
    """Add an ellipse outlining a 2-D covariance around ``mean`` to the axes.

    ``sigma`` is indexed as a 2x2 covariance matrix. The ellipse is built with
    radii sqrt(1 + r) and sqrt(1 - r), where r is the correlation coefficient,
    then rotated by 45 degrees, scaled by ``n_std`` standard deviations per
    axis and moved to ``mean``. Its edge colour is taken from the
    "gist_rainbow" colormap at ``label / max_label``.
    """
    var_x = sigma[0][0]
    var_y = sigma[1][1]
    denominator = np.sqrt(var_x * var_y)
    # A zero variance in either dimension would divide by zero; treat the
    # dimensions as uncorrelated in that degenerate case.
    pearson = sigma[0][1] / denominator if denominator > 0 else 0.0
    # Rounding can push |r| slightly above 1, which would make one of the
    # square roots below NaN.
    pearson = float(np.clip(pearson, -1.0, 1.0))
    ell_radius_x = np.sqrt(1 + pearson)
    ell_radius_y = np.sqrt(1 - pearson)
    cmap = plt.get_cmap("gist_rainbow")
    # Guard the colour lookup against max_label == 0.
    color_position = label / max_label if max_label else 0.0
    ellipse = Ellipse(
        (0, 0),
        width=ell_radius_x * 2,
        height=ell_radius_y * 2,
        edgecolor=cmap(color_position),
        facecolor="none",
    )
    scale_x = np.sqrt(var_x) * n_std
    scale_y = np.sqrt(var_y) * n_std
    transf = (
        transforms.Affine2D()
        .rotate_deg(45)
        .scale(scale_x, scale_y)
        .translate(mean[0], mean[1])
    )
    # Apply the ellipse transform first, then map data coordinates to the display.
    ellipse.set_transform(transf + self.axes.transData)
    self.axes.add_patch(ellipse)
def clusters_means_plot(
    self, means, sigmas, name_x, name_y, min_x, max_x, min_y, max_y, drawing=True
):
    """Clear the axes and draw every cluster mean with its one-sigma ellipse.

    ``means`` unpacks into the x and y coordinates of the means, and
    ``sigmas[i]`` is passed to ``_draw_variance`` for cluster ``i``. The canvas
    is redrawn when ``drawing`` is true.
    """
    ax = self.axes
    ax.cla()
    x_means, y_means = means
    cluster_count = len(x_means)
    ax.set_xlabel(name_x)
    ax.set_ylabel(name_y)
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)
    # Square markers, one colour per cluster index.
    ax.scatter(
        x_means,
        y_means,
        c=np.arange(cluster_count),
        marker="s",
        cmap="gist_rainbow",
        vmin=0,
        vmax=cluster_count,
        edgecolor="black",
        linewidths=1,
    )
    for index in range(cluster_count):
        center = [x_means[index], y_means[index]]
        self._draw_variance(center, sigmas[index], index, cluster_count, n_std=1.0)
    if drawing:
        self.draw()
class HistogramPlot(Plot):
    """Bar chart of how often each distinct value occurs in the data."""

    def __init__(self, data):
        super().__init__(data)

    def plot(self):
        """Draw the bar chart and return the canvas.

        If anything goes wrong, an error dialog is shown and None is returned.
        """
        try:
            axes = self.canvas.figure.subplots()
            frequencies = self.data.value_counts().to_dict()
            categories, heights = zip(*frequencies.items())
            axes.bar(categories, heights, align="center")
            too_many_labels = len(categories) > self.max_labels_show
            if too_many_labels:
                self._reduce_labels(axes)
            return self.canvas
        except Exception:
            dialog = QMessageBox()
            dialog.setIcon(QMessageBox.Critical)
            dialog.setText("Cannot use that plot type for selected column")
            dialog.setWindowTitle("Error")
            dialog.exec_()
            return

    def _reduce_labels(self, ax):
        """Hide x tick labels so that only every ``step``-th one stays visible."""
        step = len(ax.xaxis.get_ticklabels()) // self.max_labels_show + 1
        for position, tick_label in enumerate(ax.xaxis.get_ticklabels()):
            if position % step:
                tick_label.set_visible(False)
nulls: + values.append(nulls) + labels.append("NULLs") + if not_nulls: + values.append(not_nulls) + labels.append("Not NULLs") + ax.pie(values, labels=labels, autopct="%1.0f%%") + return self.canvas + except Exception: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText("Cannot use that plot type for selected column") + error.setWindowTitle("Error") + error.exec_() + return diff --git a/src/visualization/plots/pie_plot.py b/src/visualization/plots/pie_plot.py new file mode 100644 index 0000000..51c8790 --- /dev/null +++ b/src/visualization/plots/pie_plot.py @@ -0,0 +1,39 @@ +from PyQt5.QtWidgets import QMessageBox + +from visualization.plots import Plot + + +class PiePlot(Plot): + def __init__(self, data): + super().__init__(data) + + def plot(self): + try: + self.data.dropna(inplace=True) + ax = self.canvas.figure.subplots() + counts = self.data.value_counts().to_dict() + first_key = next(iter(counts.keys())) + data_size = self.data.size + labels = [ + k + for k in counts.keys() + if counts[k] / data_size > self.min_pie_plot_label_ratio + or k == first_key + ] + values = [counts[label] for label in labels] + values_sum = sum(values) + if values_sum < data_size: + labels.append("Other") + values.append(data_size - values_sum) + ax.pie(values, labels=labels) + return self.canvas + except Exception: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText("Cannot use that plot type for selected column") + error.setWindowTitle("Error") + error.exec_() + return + + def _reduce_labels(self, ax, frequency_ratio): + pass diff --git a/src/visualization/plots/plot.py b/src/visualization/plots/plot.py new file mode 100644 index 0000000..987dbc5 --- /dev/null +++ b/src/visualization/plots/plot.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure +from PyQt5.QtWidgets import QVBoxLayout + + +class Plot(ABC): + def 
__init__(self, data): + self.data = data + self.max_labels_show = 10 + self.min_pie_plot_label_ratio = 0.03 + self.plot_box = QVBoxLayout() + self.figure = Figure(figsize=(15, 6)) + self.canvas = FigureCanvasQTAgg(self.figure) + + @abstractmethod + def plot(self): + pass diff --git a/src/visualization/plots/scatter_plot.py b/src/visualization/plots/scatter_plot.py new file mode 100644 index 0000000..7890e9b --- /dev/null +++ b/src/visualization/plots/scatter_plot.py @@ -0,0 +1,109 @@ +import pandas as pd +from PyQt5.QtWidgets import QMessageBox + +from visualization import ClusteringCanvas +from visualization.plots import Plot + + +class ScatterPlot(Plot): + def __init__(self, data, settings=None): + super().__init__(data.select_dtypes(include=["number"])) + self.all_data = data + self.settings = settings + self.annot = None + self.connection = None + + def plot(self): + if self.data.shape[1] < 2: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText( + "Cannot use that plot type for this data (not enough numeric columns)" + ) + error.setWindowTitle("Error") + error.exec_() + return + self.canvas = ClusteringCanvas(False) + ox = self.settings["ox"] + oy = self.settings["oy"] + samples = self.settings["samples"] + group_by = self.settings["group_by"] + + samples_data = self.data.iloc[samples] + x = samples_data[ox] + y = samples_data[oy] + min_x = self.data[ox].min() + max_x = self.data[ox].max() + min_y = self.data[oy].min() + max_y = self.data[oy].max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + + if group_by: + labels = self.all_data.iloc[samples][group_by] + num_labels = pd.factorize(labels, sort=True)[0] + self.canvas.all_plot( + x, + y, + None, + None, + num_labels, + ox, + oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + drawing=False, + ) + else: + self.canvas.data_plot( + x, + y, + ox, + oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + drawing=False, + ) + 
self.annot = self.canvas.axes.annotate( + "", + xy=(0, 0), + xytext=(-20, 0), + textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + ) + self.annot.set_visible(True) + self.canvas.hover = self.hover + self.canvas.figure.canvas.mpl_connect("motion_notify_event", self.hover) + self.canvas.draw() + return self.canvas + + def hover(self, event): + vis = self.annot.get_visible() + if event.inaxes == self.canvas.axes: + cont, ind = self.canvas.sc.contains(event) + if cont: + self.update_annot(ind) + self.annot.set_visible(True) + self.canvas.figure.canvas.draw_idle() + else: + if vis: + self.annot.set_visible(False) + self.canvas.figure.canvas.draw_idle() + + def update_annot(self, ind): + pos = self.canvas.sc.get_offsets()[ind["ind"][0]] + self.annot.xy = pos + samples = self.settings["samples"] + samples_data = self.all_data.iloc[samples].iloc[ind["ind"]] + text = "\n---\n".join( + "\n".join([f"{key}: {item}" for key, item in row.items()]) + for _, row in samples_data.iterrows() + ) + self.annot.set_text(text) + self.annot.get_bbox_patch().set_alpha(0.4) + self.annot.set_visible(True) + self.canvas.figure.canvas.draw_idle() diff --git a/src/widgets/__init__.py b/src/widgets/__init__.py index e69de29..fe855ed 100644 --- a/src/widgets/__init__.py +++ b/src/widgets/__init__.py @@ -0,0 +1,13 @@ +from .config import UNFOLD_BUTTON_WIDTH +from .table_model import QtTable +from .loading_widget import LoadingWidget +from .table_model import QtTable +from .tables import MergingSetsScreen +from .unfold_widgets.unfold_widget import UnfoldWidget +from .unfold_widgets.algorithm_run_widget import AlgorithmRunWidget +from .unfold_widgets.algorithm_setup_widget import AlgorithmSetupWidget +from .unfold_widgets.import_widget import ImportWidget +from .unfold_widgets.preprocessing_widget import PreprocessingWidget +from .unfold_widgets.results_widget import ResultsWidget +from .main_widget import MainWidget +from .main_window import MainWindow diff --git 
a/src/widgets/chart_widget.py b/src/widgets/chart_widget.py deleted file mode 100644 index f07e2b7..0000000 --- a/src/widgets/chart_widget.py +++ /dev/null @@ -1,14 +0,0 @@ -from PyQt5.QtWidgets import QVBoxLayout -from visualization.chart import ChartCanvas - - -class ChartWidget(QVBoxLayout): - - def __init__(self): - super().__init__() - - self.canvas = ChartCanvas() - self.addWidget(self.canvas) - - def display_number(self, value: int): - self.canvas.add_number(value) diff --git a/src/widgets/components/__init__.py b/src/widgets/components/__init__.py new file mode 100644 index 0000000..0a5ca73 --- /dev/null +++ b/src/widgets/components/__init__.py @@ -0,0 +1,5 @@ +from .samples_column_choices import SamplesColumnsChoice +from .clustering_steps import ClusteringStepsTemplate +from .clustering_table import ClustersTable +from .parameters import ParametersGroupBox +from .tooltip_widget import QLabelWithTooltip diff --git a/src/widgets/components/clustering_steps.py b/src/widgets/components/clustering_steps.py new file mode 100644 index 0000000..7716321 --- /dev/null +++ b/src/widgets/components/clustering_steps.py @@ -0,0 +1,210 @@ +from functools import partial +from typing import Callable, List + +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from PyQt5.QtCore import Qt, pyqtSignal +from PyQt5.QtWidgets import ( + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QPushButton, + QScrollArea, + QSizePolicy, + QSpinBox, + QVBoxLayout, + QWidget, +) + +from widgets.components import SamplesColumnsChoice + + +class ClusteringStepsTemplate(QWidget): + parameters_changed = pyqtSignal() + + def __init__( + self, + columns: List[str], + max_step: int, + size: int, + description: str, + is_animation: bool, + canvas: FigureCanvasQTAgg, + get_func_animation: Callable, + ): + super().__init__() + + self.is_running = False + self.is_animation = is_animation + self.animation = None + self.get_func_animation = get_func_animation + self.layout = 
QHBoxLayout(self) + + self.max_step = max_step + self.current_step = 0 + + # left column layout + self.left_column_layout = QVBoxLayout() + + # settings layout + self.settings_box = QGroupBox() + self.settings_box.setTitle("Settings") + self.settings_box.setFixedWidth(250) + self.settings_box_layout = QVBoxLayout(self.settings_box) + + # samples columns choice + self.parameters_widget = SamplesColumnsChoice(columns, size) + self.parameters_widget.samples_columns_changed.connect( + partial(self.click_listener, "parameters_changed") + ) + self.settings_box_layout.addWidget(self.parameters_widget) + self.ox = self.parameters_widget.ox + self.oy = self.parameters_widget.oy + self.samples = self.parameters_widget.samples + + self.left_column_layout.addWidget(self.settings_box, 0) + + # visualization layout + self.visualization_box = QGroupBox() + self.visualization_box.setTitle("Visualization") + self.visualization_box_layout = QVBoxLayout(self.visualization_box) + + if self.is_animation: + # animation + self.animation_box = QGroupBox() + self.animation_box.setTitle("Animation") + self.animation_box.setFixedWidth(250) + self.animation_box_layout = QFormLayout(self.animation_box) + + self.restart_button = QPushButton("Restart") + self.restart_button.clicked.connect(partial(self.click_listener, "restart")) + self.run_button = QPushButton("Start animation") + self.run_button.clicked.connect(partial(self.click_listener, "run")) + self.interval_box = QSpinBox() + self.interval_box.setMinimum(20) + self.interval_box.setMaximum(2000) + self.interval_box.setValue(200) + self.interval_box.setSingleStep(20) + + self.animation_box_layout.addRow( + QLabel("Interval time [ms]:"), self.interval_box + ) + self.animation_box_layout.addRow(self.restart_button) + self.animation_box_layout.addRow(self.run_button) + + self.left_column_layout.addWidget(self.animation_box, 0) + + # plot + self.canvas = canvas + self.visualization_box_layout.addWidget(self.canvas, 1) + + if not 
self.is_animation: + self.visualization_box_layout.addStretch() + + # control buttons + self.control_buttons_layout = QHBoxLayout() + self.left_box = QSpinBox() + self.left_box.setMinimum(1) + self.right_box = QSpinBox() + self.right_box.setMinimum(1) + self.left_button = QPushButton("PREV") + self.left_button.clicked.connect(partial(self.click_listener, "prev")) + self.right_button = QPushButton("NEXT") + self.right_button.clicked.connect(partial(self.click_listener, "next")) + self.step_label = QLabel("STEP: {}".format(self.current_step)) + self.control_buttons_layout.addWidget(self.left_button) + self.control_buttons_layout.addWidget(self.left_box) + self.control_buttons_layout.addStretch() + self.control_buttons_layout.addWidget(self.step_label) + self.control_buttons_layout.addStretch() + self.control_buttons_layout.addWidget(self.right_box) + self.control_buttons_layout.addWidget(self.right_button) + + self.visualization_box_layout.addLayout(self.control_buttons_layout, 0) + else: + self.step_label = QLabel("STEP: {}".format(self.current_step)) + self.visualization_box_layout.addWidget( + self.step_label, 0, alignment=Qt.AlignCenter + ) + + self.description_label = QLabel(description) + self.description_label.setWordWrap(True) + + self.description_group_box = QGroupBox() + self.description_group_box.setFixedWidth(250) + self.description_group_box.setTitle("Description") + self.description_group_box_layout = QVBoxLayout(self.description_group_box) + + self.scroll_box = QGroupBox() + self.scroll_box_layout = QFormLayout(self.scroll_box) + self.scroll = QScrollArea() + self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.scroll.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.scroll.setWidget(self.scroll_box) + self.scroll.setWidgetResizable(True) + self.scroll.setMinimumHeight(26) + self.description_group_box_layout.addWidget(self.scroll) + + 
self.scroll_box_layout.addWidget(self.description_label) + + self.left_column_layout.addWidget(self.description_group_box, 1) + + self.left_column_layout.setSpacing(35) + + self.layout.addLayout(self.left_column_layout) + self.layout.addWidget(self.visualization_box) + + def click_listener(self, button_type: str): + match button_type: + case "parameters_changed": + self.ox = self.parameters_widget.ox + self.oy = self.parameters_widget.oy + self.samples = self.parameters_widget.samples + self.parameters_changed.emit() + case "prev": + num = self.left_box.value() + self.change_step(-1 * num) + case "next": + num = self.right_box.value() + self.change_step(num) + case "restart": + self.animation = None + self.change_step(-1 * self.current_step) + self.parameters_widget.change_enabled_buttons(True) + self.interval_box.setEnabled(True) + self.run_button.setEnabled(True) + case "run": + self.is_running = not self.is_running + self.parameters_widget.change_enabled_buttons(False) + self.interval_box.setEnabled(False) + self.restart_button.setEnabled(not self.is_running) + if self.is_running: + if self.animation is None: + self.animation = self.get_func_animation() + self.canvas.draw() + else: + self.animation.resume() + self.run_button.setText("Stop animation") + else: + self.animation.pause() + self.run_button.setText("Start animation") + + def change_step(self, change: int): + new_step = max(0, min(self.max_step, self.current_step + change)) + if new_step == self.current_step: + return + self.current_step = new_step + self.step_label.setText("STEP: {}".format(self.current_step)) + self.parameters_changed.emit() + self.step_label.update() + + def end_animation(self): + self.run_button.setText("Start animation") + self.parameters_widget.change_enabled_buttons(True) + self.run_button.setEnabled(False) + self.restart_button.setEnabled(True) + self.is_running = False + + def update_step_label(self): + self.step_label.setText("STEP: {}".format(self.current_step)) diff --git 
a/src/widgets/components/clustering_table.py b/src/widgets/components/clustering_table.py new file mode 100644 index 0000000..d2b8ec9 --- /dev/null +++ b/src/widgets/components/clustering_table.py @@ -0,0 +1,115 @@ +from functools import partial + +from PyQt5.QtCore import Qt, pyqtSignal +from PyQt5.QtWidgets import ( + QHBoxLayout, + QInputDialog, + QLabel, + QMessageBox, + QPushButton, + QTableView, + QVBoxLayout, + QWidget, +) + +from widgets import QtTable + + +class ClustersTable(QWidget): + table_changed = pyqtSignal() + + def __init__(self, data, labels, clusters_representative, columns_num): + super().__init__() + + self.data = data + self.labels = labels + self.selected_cluster = None + self.clusters_representative = clusters_representative + self.layout = QVBoxLayout(self) + self.clusters_table = QTableView() + self.clusters_table.setModel(QtTable(self.clusters_representative.round(3))) + self.clusters_table.doubleClicked.connect(self.show_cluster) + for i in range(columns_num): + self.clusters_table.setColumnWidth(i, 120) + + self.clusters_table_header = QWidget() + self.clusters_table_header_layout = QHBoxLayout(self.clusters_table_header) + self.clusters_table_instruction = QLabel( + "Double click on any field to preview a cluster" + ) + self.save_all_button = QPushButton("SAVE RESULTS") + self.save_all_button.clicked.connect( + partial(self.on_save_button_click, self.data.assign(cluster=self.labels)) + ) + self.save_all_button.setFixedWidth(120) + self.clusters_table_header_layout.addWidget(self.clusters_table_instruction) + self.clusters_table_header_layout.addWidget(self.save_all_button) + + self.layout.addWidget(self.clusters_table_header) + self.layout.addWidget(self.clusters_table) + + def show_cluster(self): + self.selected_cluster = ( + self.clusters_table.selectionModel().selectedIndexes()[0].row() + ) + rows = [ + i + for i in range(len(self.labels)) + if self.labels[i] == self.selected_cluster + ] + elements = self.data.iloc[rows] + 
self.clusters_table.setModel(QtTable(elements)) + buttons_widget = QWidget() + buttons_layout = QHBoxLayout() + exit_button = QPushButton("X") + exit_button.clicked.connect(self.exit_from_cluster) + exit_button.setFixedWidth(50) + save_button = QPushButton("SAVE") + save_button.clicked.connect(partial(self.on_save_button_click, elements)) + save_button.setFixedWidth(100) + buttons_layout.addWidget(save_button) + buttons_layout.addWidget(exit_button) + buttons_layout.setAlignment(Qt.AlignRight) + buttons_widget.setLayout(buttons_layout) + self.layout.insertWidget(0, buttons_widget) + self.clusters_table_header.hide() + self.clusters_table.doubleClicked.disconnect() + self.table_changed.emit() + + def exit_from_cluster(self): + self.clusters_table.setModel(QtTable(self.clusters_representative.round(3))) + self.layout.itemAt(0).widget().setParent(None) + self.clusters_table_header.show() + self.clusters_table.doubleClicked.connect(self.show_cluster) + self.selected_cluster = None + self.table_changed.emit() + + def on_save_button_click(self, elements): + path, is_ok = QInputDialog.getText(self, "Save to file", "Enter filename") + if is_ok and path: + if not path.endswith(".csv"): + path += ".csv" + try: + elements.to_csv(path) + except Exception: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText( + "Something wrong happened while writing data to file. Try again." 
+ ) + error.setWindowTitle("Saving failed") + error.exec_() + elif not is_ok: + pass + elif not path: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText("No path was provided") + error.setWindowTitle("Empty path") + error.exec_() + else: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText("This file extension is not supported.") + error.setWindowTitle("Unsupported extension") + error.exec_() diff --git a/src/widgets/components/parameters.py b/src/widgets/components/parameters.py new file mode 100644 index 0000000..0db6979 --- /dev/null +++ b/src/widgets/components/parameters.py @@ -0,0 +1,13 @@ +from typing import Dict + +from PyQt5.QtWidgets import QFormLayout, QGroupBox, QLabel + + +class ParametersGroupBox(QGroupBox): + def __init__(self, info: Dict[str, any], title: str = "Parameters"): + super().__init__() + + self.setTitle(title) + self.layout = QFormLayout(self) + for option, value in info.items(): + self.layout.addRow(QLabel(f"{option}:"), QLabel(f"{value}")) diff --git a/src/widgets/components/samples_column_choices.py b/src/widgets/components/samples_column_choices.py new file mode 100644 index 0000000..f203ead --- /dev/null +++ b/src/widgets/components/samples_column_choices.py @@ -0,0 +1,109 @@ +from functools import partial + +from PyQt5.QtCore import pyqtSignal +from PyQt5.QtWidgets import ( + QComboBox, + QFormLayout, + QLabel, + QPushButton, + QSpinBox, + QWidget, +) + +from algorithms import get_samples + + +class SamplesColumnsChoice(QWidget): + samples_columns_changed = pyqtSignal() + samples_changed = pyqtSignal() + + def __init__(self, columns=None, size=0): + super().__init__() + + if columns is None or len(columns) == 0: + columns = [""] + + self.layout = QFormLayout(self) + + self.size = size + self.num_samples = min(35, self.size // 2) + self.samples = get_samples(self.size, self.num_samples) + + self.ox = columns[0] + self.oy = columns[0] if len(columns) < 2 else columns[1] + + # samples 
+ self.layout.addRow(QLabel("Set samples:")) + self.sample_box = QSpinBox() + self.sample_box.setMinimum(0) + self.sample_box.setMaximum(min(self.size, 10000)) + self.sample_box.setProperty("value", self.num_samples) + self.sample_button = QPushButton("Refresh samples") + self.sample_button.clicked.connect(partial(self.click_listener, "new_samples")) + self.layout.addRow(self.sample_box, self.sample_button) + + # axis + self.layout.addRow(QLabel("Set axes:")) + self.ox_box = QComboBox() + self.ox_box.addItems(columns) + self.oy_box = QComboBox() + self.oy_box.addItems(columns) + + self.ox_box.setMinimumWidth(100) + self.oy_box.setMinimumWidth(100) + + if len(columns) > 1: + self.oy_box.setCurrentIndex(1) + self.ox_box.currentTextChanged.connect(partial(self.click_listener, "set_axis")) + self.oy_box.currentTextChanged.connect(partial(self.click_listener, "set_axis")) + self.layout.addRow(QLabel("OX:"), self.ox_box) + self.layout.addRow(QLabel("OY:"), self.oy_box) + + def click_listener(self, button_type: str): + match button_type: + case "new_samples": + num = self.sample_box.value() + self.num_samples = num + self.samples = get_samples(self.size, self.num_samples) + self.samples_columns_changed.emit() + self.samples_changed.emit() + case "set_axis": + self.ox = self.ox_box.currentText() + self.oy = self.oy_box.currentText() + self.samples_columns_changed.emit() + + def new_columns_name(self, columns): + self.ox = columns[0] + self.oy = columns[0] if len(columns) < 2 else columns[1] + + self.ox_box.clear() + self.oy_box.clear() + self.ox_box.addItems(columns) + self.oy_box.addItems(columns) + if len(columns) > 1: + self.oy_box.setCurrentIndex(1) + self.samples_columns_changed.emit() + + def new_size(self, size): + self.size = size + self.num_samples = min(100, self.size // 2) + self.samples = get_samples(self.size, self.num_samples) + self.sample_box.setMaximum(min(self.size, 10000)) + self.sample_box.setProperty("value", self.num_samples) + 
self.samples_columns_changed.emit() + self.samples_changed.emit() + + def get_parameters(self): + parameters = {"ox": self.ox, "oy": self.oy, "samples": self.samples} + return parameters + + def change_enabled_buttons(self, value): + self.ox_box.setEnabled(value) + self.oy_box.setEnabled(value) + self.sample_button.setEnabled(value) + self.sample_box.setEnabled(value) + + def reset(self): + self.ox_box.clear() + self.oy_box.clear() + self.sample_box.clear() diff --git a/src/widgets/components/tooltip_widget.py b/src/widgets/components/tooltip_widget.py new file mode 100644 index 0000000..40a1446 --- /dev/null +++ b/src/widgets/components/tooltip_widget.py @@ -0,0 +1,23 @@ +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QIcon +from PyQt5.QtWidgets import QHBoxLayout, QLabel, QWidget + + +class QLabelWithTooltip(QWidget): + INFO_ICON_FILE = "../static/img/info_icon.svg" + + def __init__(self, label: str, description: str = ""): + super().__init__() + + self.layout = QHBoxLayout(self) + self.layout.setContentsMargins(0, 0, 0, 0) + + self.label = QLabel(label) + self.layout.addWidget(self.label, alignment=Qt.AlignVCenter) + + if description: + self.icon = QIcon(self.INFO_ICON_FILE).pixmap(14) + self.icon_label = QLabel() + self.icon_label.setPixmap(self.icon) + self.icon_label.setToolTip(description) + self.layout.addWidget(self.icon_label) diff --git a/src/widgets/config.py b/src/widgets/config.py new file mode 100644 index 0000000..c74d988 --- /dev/null +++ b/src/widgets/config.py @@ -0,0 +1 @@ +UNFOLD_BUTTON_WIDTH = 45 diff --git a/src/widgets/data_generator_widget.py b/src/widgets/data_generator_widget.py new file mode 100644 index 0000000..161a422 --- /dev/null +++ b/src/widgets/data_generator_widget.py @@ -0,0 +1,227 @@ +from functools import partial +from typing import Callable, Dict, Tuple, Type + +import matplotlib.pyplot as plt +from PyQt5.QtCore import QRect +from PyQt5.QtWidgets import ( + QComboBox, + QFormLayout, + QGroupBox, + QHBoxLayout, + 
QLabel, + QMessageBox, + QPushButton, + QSizePolicy, + QTableView, + QVBoxLayout, + QWidget, +) + +from data_generators import ( + DataGeneratorFunction, + noncentral_f_blobs_generator, + normal_distribution_blobs_generator, +) +from engines import ImportDataEngine +from visualization.plots import ScatterPlot +from widgets import QtTable +from widgets.components import SamplesColumnsChoice +from widgets.options_widgets import ( + AlgorithmOptions, + NoncentralFClusteringOptions, + NormalDistributionClusteringOptions, +) + + +class DataGeneratorWidget(QWidget): + def __init__(self, engine: ImportDataEngine, callback: Callable): + super().__init__() + self.engine = engine + self.callback = callback + self.plot_connection = None + + self.generated_data = None + self.setWindowTitle("Data generator") + self.setGeometry(QRect(400, 400, 800, 400)) + + self.setObjectName("data_generator_widget") + with open("../static/css/styles.css") as stylesheet: + self.setStyleSheet(stylesheet.read()) + + self.dataset_types_config: Dict[ + str, Tuple[DataGeneratorFunction, Type[AlgorithmOptions]] + ] = { + "(Clustering) Normal distribution blobs": ( + normal_distribution_blobs_generator, + NormalDistributionClusteringOptions, + ), + "(Clustering) Noncental F distribution blobs": ( + noncentral_f_blobs_generator, + NoncentralFClusteringOptions, + ), + } + + self.layout = QHBoxLayout() + self._render_algorithm_selection() + self._render_options() + + self.generate_button = QPushButton(self) + self.generate_button.setText("Generate") + self.generate_button.clicked.connect(partial(self.click_listener, "generate")) + + self.load_button = QPushButton(self) + self.load_button.setText("Load") + self.load_button.clicked.connect(partial(self.click_listener, "load")) + self.load_button.setEnabled(False) + + self.cancel_button = QPushButton(self) + self.cancel_button.setText("Cancel") + self.cancel_button.clicked.connect(partial(self.click_listener, "cancel")) + + self._render_data() + 
self._render_scatter_plot() + + self.left_column = QVBoxLayout() + self.left_column.addWidget(self.algorithm_group) + self.left_column.addStretch() + self.left_column.addWidget(self.options_group) + self.left_column.addStretch() + self.left_column.addWidget(self.generate_button) + self.left_column.addWidget(self.load_button) + self.left_column.addWidget(self.cancel_button) + + self.right_column = QVBoxLayout() + self.right_column.addWidget(self.data_group, 2) + self.right_column.addWidget(self.scatter_plot_group, 1) + + self.layout.addLayout(self.left_column, 0) + self.layout.addLayout(self.right_column, 1) + + self.hide() + self.setLayout(self.layout) + + def _render_algorithm_selection(self): + self.algorithm_group = QGroupBox(self) + self.algorithm_group.setTitle("Algorithm selection") + self.algorithm_group_layout = QVBoxLayout(self.algorithm_group) + + self.dataset_type_label = QLabel("Select dataset type:") + self.dataset_type_box = QComboBox(self.algorithm_group) + self.dataset_type_box.addItems(self.dataset_types_config.keys()) + + self.dataset_type_box.currentTextChanged.connect( + partial(self.click_listener, "dataset_type") + ) + + self.algorithm_group_layout.addWidget(self.dataset_type_label) + self.algorithm_group_layout.addWidget(self.dataset_type_box) + + def _render_options(self): + self.options_group = QGroupBox(self) + self.options_group.setTitle("Options") + self.options_group_layout = QFormLayout(self.options_group) + self._set_data_generator(self.dataset_type_box.currentText()) + + def _set_data_generator(self, dataset_type): + self.options_widget = self.dataset_types_config[dataset_type][1]() + + if item := self.options_group_layout.itemAt(0): + self.options_group_layout.removeWidget(item.widget()) + + self.options_group_layout.addWidget(self.options_widget) + self.selected_generator = self.dataset_types_config[dataset_type][0] + + def _render_data(self): + self.data_group = QGroupBox(self) + self.data_group.setTitle("Generated data") + 
self.data_group_layout = QVBoxLayout(self.data_group) + + self.data_table = QTableView(self.data_group) + self.data_table.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.data_group_layout.addWidget(self.data_table) + + def _render_scatter_plot(self): + self.scatter_plot_group = QGroupBox(self) + self.scatter_plot_group.setTitle("Plot") + self.scatter_plot_group_layout = QHBoxLayout(self.scatter_plot_group) + + self.parameters_widget = SamplesColumnsChoice() + + self.plot_layout = QVBoxLayout() + + self.scatter_plot_group_layout.addWidget(self.parameters_widget, 0) + self.scatter_plot_group_layout.addLayout(self.plot_layout, 1) + + def _plot_data(self): + self._reset_plot() + if self.generated_data is not None: + parameters = self.parameters_widget.get_parameters() + parameters["group_by"] = None + self.plot = ScatterPlot(self.generated_data, parameters).plot() + self.plot_layout.addWidget(self.plot) + + def _reset_plot(self): + if item := self.plot_layout.itemAt(0): + item.widget().setParent(None) + plt.close(self.plot.figure) + + def _reset(self): + self.generated_data = None + self.data_table.setModel(None) + self.load_button.setEnabled(False) + self.parameters_widget.change_enabled_buttons(False) + self.parameters_widget.reset() + + self.dataset_type_box.setCurrentIndex(0) + self._set_data_generator(self.dataset_type_box.currentText()) + + self._reset_plot() + + def click_listener(self, button_type: str): + match button_type: + case "dataset_type": + self._set_data_generator(self.dataset_type_box.currentText()) + case "generate": + try: + options = self.options_widget.get_data() + except Exception as e: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText( + f"The format of the set options is incorrect. 
{f'({e})' if e else ''}" + ) + error.setWindowTitle("Error") + error.exec_() + else: + self.generated_data = self.selected_generator(options) + self.data_table.setModel(QtTable(self.generated_data)) + self.load_button.setEnabled(self.generated_data is not None) + if self.plot_connection: + self.parameters_widget.samples_columns_changed.disconnect( + self.plot_connection + ) + self.plot_connection = None + if self.generated_data is not None: + self.parameters_widget.change_enabled_buttons(True) + self.parameters_widget.new_columns_name( + list(self.generated_data.columns) + ) + self.parameters_widget.new_size(len(self.generated_data)) + self.plot_connection = ( + self.parameters_widget.samples_columns_changed.connect( + self._plot_data + ) + ) + self._plot_data() + case "cancel": + self._reset() + self.hide() + case "load": + self.engine.set_generated_data(self.generated_data) + self._reset() + self.callback() + self.hide() + + def show(self) -> None: + self._reset() + super().show() diff --git a/src/widgets/generate_widget.py b/src/widgets/generate_widget.py deleted file mode 100644 index f252334..0000000 --- a/src/widgets/generate_widget.py +++ /dev/null @@ -1,18 +0,0 @@ -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import QVBoxLayout, QLabel, QPushButton, QMainWindow - - -class GenerateWidget(QVBoxLayout): - - def __init__(self, parent: QMainWindow): - super().__init__() - - self.label = QLabel('') - self.addWidget(self.label, alignment=Qt.AlignCenter) - - self.button = QPushButton('GENERATE') - self.button.clicked.connect(parent.on_click_listener) - self.addWidget(self.button) - - def display_number(self, value: int): - self.label.setText(str(value)) diff --git a/src/widgets/loading_widget.py b/src/widgets/loading_widget.py new file mode 100644 index 0000000..a9d8c16 --- /dev/null +++ b/src/widgets/loading_widget.py @@ -0,0 +1,23 @@ +from PyQt5.QtCore import QRect, Qt +from PyQt5.QtWidgets import QApplication, QDesktopWidget, QSplashScreen + + +class 
LoadingWidget: + def __init__(self, callback, *args): + self.screen = QSplashScreen() + self.size = QDesktopWidget().screenGeometry(-1) + self.callback = callback + self.args = args + + def execute(self): + self.screen.showMessage("

Loading...

", Qt.AlignCenter) + self.screen.setGeometry( + QRect(self.size.width() // 2 - 125, self.size.height() // 2 - 50, 250, 100) + ) + self.screen.show() + QApplication.processEvents() + if self.args: + self.callback(*self.args) + else: + self.callback() + self.screen.close() diff --git a/src/widgets/main_layout.py b/src/widgets/main_layout.py deleted file mode 100644 index da2e69a..0000000 --- a/src/widgets/main_layout.py +++ /dev/null @@ -1,32 +0,0 @@ -from PyQt5.QtWidgets import QMainWindow, QHBoxLayout, QWidget - -from algorithms.generate_number import NumberGenerator -from widgets.chart_widget import ChartWidget -from widgets.generate_widget import GenerateWidget - - -class RandomGenerator(QMainWindow): - - def __init__(self): - super().__init__() - self.setWindowTitle('Random Generator') - self.setFixedSize(235*6, 235*2) - self.generalLayout = QHBoxLayout() - self._centralWidget = QWidget(self) - self.setCentralWidget(self._centralWidget) - self._centralWidget.setLayout(self.generalLayout) - - self.generator = NumberGenerator(0, 10) - self.generate_widget = GenerateWidget(self) - self.chart_widget = ChartWidget() - self.generalLayout.addLayout(self.generate_widget) - self.generalLayout.addLayout(self.chart_widget) - - self.show() - - def on_click_listener(self): - value = self.generator.get_number() - self.generate_widget.display_number(value) - self.chart_widget.display_number(value) - - diff --git a/src/widgets/main_widget.py b/src/widgets/main_widget.py new file mode 100644 index 0000000..f3484ad --- /dev/null +++ b/src/widgets/main_widget.py @@ -0,0 +1,57 @@ +from PyQt5.QtWidgets import QWIDGETSIZE_MAX, QHBoxLayout, QWidget + +from widgets import ( + UNFOLD_BUTTON_WIDTH, + AlgorithmRunWidget, + AlgorithmSetupWidget, + ImportWidget, + PreprocessingWidget, + ResultsWidget, +) + + +class MainWidget(QWidget): + def __init__(self, engines): + super().__init__() + + self.import_widget = ImportWidget(self, engines["import_data"]) + self.preprocessing_widget = 
class MainWindow(QMainWindow):
    """
    Top-level window of the application, titled "Data Mining Tool".

    The window is created with a 1200x600 geometry, moved to the centre of
    the available desktop area, styled from the CSS stylesheet and shown
    right away. Its only content is a ``MainWidget`` built from ``engines``.
    """

    def __init__(self, engines):
        """
        Build, style and show the window.

        :param engines: passed unchanged to ``MainWidget``; not inspected here.
        :raises OSError: when the stylesheet file cannot be opened.
        """
        super().__init__()
        self.setWindowTitle("Data Mining Tool")
        self.setGeometry(0, 0, 1200, 600)

        # position the window in the middle of the screen
        rect = self.frameGeometry()
        rect.moveCenter(QDesktopWidget().availableGeometry().center())
        self.move(rect.topLeft())

        # the central widget only hosts the horizontal layout holding MainWidget
        self.generalLayout = QHBoxLayout()
        self._centralWidget = QWidget(self)
        self.setCentralWidget(self._centralWidget)
        self._centralWidget.setLayout(self.generalLayout)

        # Explicit encoding: without it the file is decoded with the platform
        # locale, which is not UTF-8 everywhere and would garble or reject
        # non-ASCII characters in the stylesheet.
        # NOTE(review): the path is relative to the process working directory,
        # so the application has to be started from the expected folder.
        with open("../static/css/styles.css", encoding="utf-8") as stylesheet:
            self.setStyleSheet(stylesheet.read())

        self.generalLayout.addWidget(MainWidget(engines))

        self.show()
QDoubleSpinBox() + self.min_support_spinbox.setMinimum(0.01) + self.min_support_spinbox.setValue(0.05) + self.min_support_spinbox.setMaximum(1) + self.min_support_spinbox.setSingleStep(0.1) + self.layout.addRow( + QLabelWithTooltip( + "Minimum support:", + "Minimum percentage of transactions containing a specific subset for it to be considered a frequent set.\n" + "Too low of a value will result in a very big number of sets found which may affect performance.", + ), + self.min_support_spinbox, + ) + + self.min_confidence_spinbox = QDoubleSpinBox() + self.min_confidence_spinbox.setMinimum(0.01) + self.min_confidence_spinbox.setValue(0.1) + self.min_confidence_spinbox.setMaximum(1) + self.layout.addRow( + QLabelWithTooltip( + "Minimum confidence:", + "Confidence of a rule is calculated by dividing the probability of the items occurring together by the probability of the occurrence of the antecedent.\n" + "It signifies how strong the association rule really is.\n" + "Too low of a value may lead to obtaining a high amount of rules.\n" + "Bigger value will limit rules to just the strongest selection.", + ), + self.min_confidence_spinbox, + ) + + self.index_columns_combobox = QComboBox() + self.layout.addRow( + QLabelWithTooltip( + "Index column:", "Column containing each transaction's ID" + ), + self.index_columns_combobox, + ) + + def get_data(self) -> dict: + return { + "min_support": self.min_support_spinbox.value(), + "min_confidence": self.min_confidence_spinbox.value(), + "index_column": self.index_columns_combobox.currentText(), + } + + def set_columns_options(self, columns: List[str]): + self.index_columns_combobox.clear() + self.index_columns_combobox.addItems(columns) diff --git a/src/widgets/options_widgets/dataset_options.py b/src/widgets/options_widgets/dataset_options.py new file mode 100644 index 0000000..e8d4bb2 --- /dev/null +++ b/src/widgets/options_widgets/dataset_options.py @@ -0,0 +1,191 @@ +from typing import Dict + +from PyQt5.QtWidgets import 
class ClusteringBlobsDataOptions(AlgorithmOptions):
    """
    Form with the settings shared by the blob-based data set generators:
    number of blobs, sample size per blob, number of dimensions,
    additional noise percentage and random seed.
    """

    def __init__(self):
        super().__init__()
        self.number_of_blobs_box = QSpinBox()
        self.number_of_blobs_box.setRange(0, 20)
        self.number_of_blobs_box.setValue(2)
        self.layout.addRow(QLabel("Number of blobs:"), self.number_of_blobs_box)

        self.sample_size_input = QLineEdit("50 30")
        self.layout.addRow(
            QLabelWithTooltip(
                "Sample sizes per blob:",
                "Series of integer values separated by a single space.\n"
                "There should be as many numbers listed as the number of blobs.\n"
                "If just one value is provided, it will be used for each blob.",
            ),
            self.sample_size_input,
        )

        self.number_of_dims_box = QSpinBox()
        self.number_of_dims_box.setRange(1, 10)
        self.number_of_dims_box.setValue(2)
        self.layout.addRow(QLabel("Number of dimensions:"), self.number_of_dims_box)

        self.noise_box = QSpinBox()
        self.noise_box.setValue(0)
        self.noise_box.setMaximum(100)

        self.layout.addRow(
            QLabelWithTooltip(
                "Additional noise percentage:",
                "How many more points that do not fall into desired pattern\n"
                "will be added to the generated data set.\n"
                "Counted as percentage of provided total sample size.",
            ),
            self.noise_box,
        )

        # the minimum value (0) is displayed as "random"
        self.seed_box = QSpinBox()
        self.seed_box.setRange(0, 10)
        self.seed_box.setSpecialValueText("random")

        self.layout.addRow(
            QLabelWithTooltip(
                "Seed:",
                "Used for setting state of the randomizer.\n"
                "Setting this field to a chosen value will ensure getting the same results every time.",
            ),
            self.seed_box,
        )

    def get_data(self) -> Dict:
        """
        Read the form and return the generator settings.

        The returned dict has the keys ``sample_sizes`` (one int per blob),
        ``dims_number``, ``blobs_number``, ``seed`` (``None`` when the seed
        box is at 0, i.e. "random") and ``noise`` (the percentage divided
        by 100).

        :raises ValueError: when a sample size is not an integer, or when the
            number of sample sizes is neither 1 nor the number of blobs.
        """
        blobs_number = self.number_of_blobs_box.value()

        # split() without an argument tolerates leading/trailing whitespace and
        # repeated separators; split(" ") produced empty tokens for such input
        # and failed in int("") with an unhelpful message.
        try:
            sample_sizes = [int(token) for token in self.sample_size_input.text().split()]
        except ValueError as err:
            raise ValueError(
                "Sample sizes must be integers separated by spaces"
            ) from err
        provided_sample_sizes = len(sample_sizes)

        if provided_sample_sizes != 1 and provided_sample_sizes != blobs_number:
            raise ValueError("Incorrect number of provided sample sizes")

        # a single value is reused for every blob
        if provided_sample_sizes == 1:
            sample_sizes *= blobs_number

        dims_number = self.number_of_dims_box.value()

        # 0 is the "random" special value: report it as "no fixed seed"
        seed = self.seed_box.value()
        if not seed:
            seed = None

        noise = self.noise_box.value()

        return {
            "sample_sizes": sample_sizes,
            "dims_number": dims_number,
            "blobs_number": blobs_number,
            "seed": seed,
            "noise": noise / 100,
        }
class ExtraTreesOptions(AlgorithmOptions):
    """
    Form with the parameters of the extra trees algorithm: label column,
    forest size, number of sampled features, minimum samples in a child,
    maximum depth, minimum metrics change and metrics type.
    """

    def __init__(self):
        super().__init__()

        self.label_name_box = QComboBox()
        self.layout.addRow(
            QLabelWithTooltip("Column with labels:", "Column with target values."),
            self.label_name_box,
        )

        self.forest_size_spinbox = QSpinBox()
        self.forest_size_spinbox.setMinimum(1)
        self.forest_size_spinbox.setMaximum(200)
        self.forest_size_spinbox.setValue(50)
        self.layout.addRow(
            QLabelWithTooltip(
                "Number of trees:",
                "The number of the trees in the forest. The more, the better.",
            ),
            self.forest_size_spinbox,
        )

        # the maximum and the default value are set later by set_values()
        self.features_number_spinbox = QSpinBox()
        self.features_number_spinbox.setMinimum(1)
        self.layout.addRow(
            QLabelWithTooltip(
                "Number of features to sample:",
                "The number of features to consider when looking for the best split.\nRecommended tu use sqrt or log2 of number of columns.\nThe default value is sqrt.",
            ),
            self.features_number_spinbox,
        )

        self.min_child_number_spinbox = QSpinBox()
        self.min_child_number_spinbox.setMinimum(1)
        self.min_child_number_spinbox.setMaximum(1000)
        self.min_child_number_spinbox.setValue(1)
        self.layout.addRow(
            QLabelWithTooltip(
                "Minimum number of samples in child:",
                "The minimum number of samples required to be at a leaf node.\nA split point will only be considered if it leaves at least this number of training samples in both branches.",
            ),
            self.min_child_number_spinbox,
        )

        # the minimum value (1) is displayed as "no limit"
        self.max_depth_spinbox = QSpinBox()
        self.max_depth_spinbox.setMinimum(1)
        self.max_depth_spinbox.setMaximum(100)
        self.max_depth_spinbox.setSpecialValueText("no limit")
        self.max_depth_spinbox.setValue(1)
        self.layout.addRow(
            QLabelWithTooltip(
                "Maximum depth:",
                "The maximum depth of the tree.\nIf 'no limit', then nodes are expanded until all leaves are pure\nor until algorithm do not draw threshold fulfilling requirements\nabout number of samples in child and minimum metrics change.",
            ),
            self.max_depth_spinbox,
        )

        self.min_metrics_spinbox = QDoubleSpinBox()
        self.min_metrics_spinbox.setMinimum(0)
        self.min_metrics_spinbox.setMaximum(1)
        self.min_metrics_spinbox.setValue(0)
        self.min_metrics_spinbox.setSingleStep(0.01)
        self.layout.addRow(
            QLabelWithTooltip(
                "Minimum metrics change:",
                "A node will be split if this split induces a decrease of the impurity greater than or equal to this value.",
            ),
            self.min_metrics_spinbox,
        )

        self.metrics_type_box = QComboBox()
        self.metrics_type_box.addItems(["gini", "entropy"])
        self.layout.addRow(
            QLabelWithTooltip(
                "Metrics type:",
                "The function to measure the quality of a split.\n'gini': Gini impurity\n'entropy': based on Shannon information gain",
            ),
            self.metrics_type_box,
        )

    def get_data(self) -> dict:
        """
        Return dict with parameters

        ``max_depth`` is ``None`` while the depth box shows its minimum
        value ("no limit"); every other entry is the widget value as is.
        """
        max_depth = (
            self.max_depth_spinbox.value()
            if self.max_depth_spinbox.value() > 1
            else None
        )
        return {
            "label_name": self.label_name_box.currentText(),
            "forest_size": self.forest_size_spinbox.value(),
            "features_number": self.features_number_spinbox.value(),
            "min_child_number": self.min_child_number_spinbox.value(),
            "max_depth": max_depth,
            "min_metrics": self.min_metrics_spinbox.value(),
            "metrics_type": self.metrics_type_box.currentText(),
        }

    def set_values(self, columns: list):
        """
        Fill the label combo box with ``columns`` and adapt the features box.

        The features box is limited to ``len(columns) - 1`` and preset to
        the rounded square root of that limit.

        :param columns: column names offered as label column.
        """
        self.label_name_box.clear()
        self.label_name_box.addItems(columns)
        # Clamp to 1: with fewer than two columns the limit would be 0 or -1,
        # which pushed the spinbox maximum below its minimum and made sqrt()
        # raise a math domain error for an empty list.
        max_number = max(1, len(columns) - 1)
        self.features_number_spinbox.setMaximum(max_number)
        self.features_number_spinbox.setValue(max(1, round(sqrt(max_number))))
class KMeansOptions(AlgorithmOptions):
    """
    Form with the parameters of the k-means algorithm: number of clusters,
    type of initial solution, exponent in metrics, iteration limit and
    number of repetitions.
    """

    def __init__(self):
        super().__init__()

        clusters = QSpinBox()
        clusters.setMinimum(2)
        clusters.setValue(3)
        self.num_clusters_spinbox = clusters
        self._add_row(
            "Number of clusters:",
            "Clusters number depends on the nature of the data.\nScatter plot (in the PREPROCESSING section)\ncan be helpful to enter correct number.",
            clusters,
        )

        # "kmeans++" (index 1) is preselected
        start_type = QComboBox()
        start_type.addItems(["random", "kmeans++"])
        start_type.setCurrentIndex(1)
        self.start_type_box = start_type
        self._add_row(
            "Type of initial solution:",
            "Method for centroids initialization. Choose rows from dataset.\n'random': selects random rows from dataset\n'kmeans++': uses sampling based on an empirical probability distribution of the points’ contribution to the overall inertia\n You should use 'kmeans++', which speeds up convergence.",
            start_type,
        )

        exponent = QSpinBox()
        exponent.setMinimum(1)
        exponent.setValue(2)
        exponent.setMaximum(6)
        self.metrics_spinbox = exponent
        self._add_row(
            "Exponent in metrics:",
            "Define value of p in the p-norm space. The default value is 2.",
            exponent,
        )

        # the minimum value (0) is displayed as "no limit"
        steps = QSpinBox()
        steps.setMinimum(0)
        steps.setMaximum(1000)
        steps.setSpecialValueText("no limit")
        steps.setValue(0)
        self.num_steps_spinbox = steps
        self._add_row(
            "Maximum number of iterations:",
            "Maximum number of iterations of the k-means algorithm for a single run.\n'no limit' option exists, because k-means converges quickly.",
            steps,
        )

        repeats = QSpinBox()
        repeats.setMinimum(1)
        repeats.setMaximum(100)
        repeats.setValue(5)
        self.num_repeat_spinbox = repeats
        self._add_row(
            "Number of repetitions:",
            "Number of time the algorithm will be run with different initial centroids.\nThe final results will be the best output based on Dunn index.",
            repeats,
        )

    def _add_row(self, caption, hint, field):
        """
        Append one form row: a label with tooltip on the left, ``field`` on the right
        """
        self.layout.addRow(QLabelWithTooltip(caption, hint), field)

    def get_data(self) -> dict:
        """
        Return dict with parameters

        ``iterations`` is ``None`` while the iterations box is at 0 ("no limit").
        """
        step_limit = self.num_steps_spinbox.value()
        if not step_limit:
            step_limit = None

        return dict(
            num_clusters=self.num_clusters_spinbox.value(),
            metrics=self.metrics_spinbox.value(),
            repeats=self.num_repeat_spinbox.value(),
            iterations=step_limit,
            init_type=self.start_type_box.currentText(),
        )

    def set_max_clusters(self, clusters_num):
        """
        Set the upper bound of the clusters box to ``clusters_num``
        """
        self.num_clusters_spinbox.setMaximum(clusters_num)
b/src/widgets/results_widgets/a_priori_results.py @@ -0,0 +1,124 @@ +from typing import List + +import pandas as pd +from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QTableView, QVBoxLayout + +from visualization import APrioriGauge, APrioriGraphPlot, APrioriScatterPlot +from widgets import QtTable +from widgets.components import ParametersGroupBox +from widgets.results_widgets import AlgorithmResultsWidget + + +class APrioriResultsWidget(AlgorithmResultsWidget): + def __init__( + self, + data: pd.DataFrame, + frequent_sets: pd.DataFrame, + association_rules: pd.DataFrame, + transaction_sets: List[set], + options, + metrics_info, + ): + super().__init__(data, options, metrics_info) + + self.transaction_sets = transaction_sets + self.frequent_sets = frequent_sets.reset_index() + self.frequent_sets.rename(columns={"index": "frequent sets"}, inplace=True) + self.association_rules = association_rules.reset_index() + self.association_rules.rename( + columns={"index": "association rules"}, inplace=True + ) + self.columns = self.data.columns.values + self.min_support = self.options["min_support"] + self.min_confidence = self.options["min_confidence"] + + self.layout = QHBoxLayout(self) + + self.left_column = QVBoxLayout() + self.right_column = QVBoxLayout() + + # algorithm parameters + self.params_group = ParametersGroupBox(self.options) + + # algorithm metrics + if self.metrics_info: + self.metrics_group = ParametersGroupBox(self.metrics_info, "Metrics") + + # sets plots and charts + self.gauge_chart = APrioriGauge() + self.gauge_chart.layout().setContentsMargins(0, 0, 0, 0) + self.graph_plot = APrioriGraphPlot() + self.graph_plot.layout().setContentsMargins(0, 0, 0, 0) + self.graph_plot.show_placeholder() + + self.transactions_canvas = APrioriScatterPlot(transaction_sets) + self.transactions_canvas.reset() + + # frequent sets group + self.frequent_sets_result_group = QGroupBox() + self.frequent_sets_result_group_layout = QVBoxLayout( + self.frequent_sets_result_group 
+ ) + self.frequent_sets_result_group.setTitle("Frequent sets result") + + # frequent sets table + self.frequent_sets_table = QTableView() + self.frequent_sets_table.setModel(QtTable(self.frequent_sets)) + self.frequent_sets_result_group_layout.addWidget(self.frequent_sets_table) + self.frequent_sets_table.clicked.connect(self.highlight_frequent_set) + + # association rules group + self.association_rules_group = QGroupBox() + self.association_rules_group_layout = QVBoxLayout(self.association_rules_group) + self.association_rules_group.setTitle("Association rules") + + # association rules table + self.association_rules_table = QTableView() + self.association_rules_table.setModel(QtTable(self.association_rules)) + self.association_rules_group_layout.addWidget(self.association_rules_table) + self.association_rules_table.clicked.connect(self.highlight_rule) + + self.right_column.addWidget(self.graph_plot, 1) + self.right_column.addWidget(self.gauge_chart, 1) + self.right_column.addWidget(self.transactions_canvas, 1) + + self.left_column.addWidget(self.params_group, 0) + if self.metrics_info: + self.left_column.addWidget(self.metrics_group, 0) + self.left_column.addWidget(self.frequent_sets_result_group, 1) + self.left_column.addWidget(self.association_rules_group, 1) + + self.layout.addLayout(self.left_column, 1) + self.layout.addLayout(self.right_column, 1) + + def highlight_frequent_set(self): + selected_set = ( + self.frequent_sets_table.selectionModel().selectedIndexes()[0].row() + ) + column_list = self.frequent_sets.iloc[selected_set]["frequent sets"][ + 1:-1 + ].split(", ") + self.graph_plot.plot_set(column_list) + self.gauge_chart.plot_value( + self.frequent_sets.iloc[selected_set]["support"], + self.min_support, + "support", + ) + self.transactions_canvas.plot_set(column_list) + + def highlight_rule(self): + selected_rule = ( + self.association_rules_table.selectionModel().selectedIndexes()[0].row() + ) + set_a, set_b = 
class ExtraTreesResultsWidget(AlgorithmResultsWidget):
    """
    Results view of the extra trees algorithm.

    Shows, from left to right: the algorithm parameters (with the metrics
    below them when there are any), the feature importance list and a
    prediction form whose results are collected in a scrollable list.
    """

    def __init__(
        self, data, predict, configs, feature_importance, options, metrics_info
    ):
        """
        :param data: passed to the base widget.
        :param predict: callable invoked with one row of typed input values;
            the mapping it returns is shown as ``label: percentage`` lines.
        :param configs: iterable of ``(label, type)`` pairs; one input field
            is created per pair and the type is used to convert the input.
        :param feature_importance: mapping ``feature -> value``, values are
            shown rounded to two decimals.
        :param options: algorithm parameters shown in the parameters box.
        :param metrics_info: metrics shown in the "Metrics" box; the box is
            omitted when this is empty.
        """
        super().__init__(data, options, metrics_info)

        self.predict = predict
        self.configs = configs
        self.feature_importance = feature_importance
        # the "Results:" header is added together with the first result
        self.first_prediction = True

        self.layout = QHBoxLayout(self)

        # algorithm parameters and metrics
        self.params_group = ParametersGroupBox(self.options)
        if self.metrics_info:
            self.metrics_group = ParametersGroupBox(self.metrics_info, "Metrics")
            self.params_metric_layout = QVBoxLayout()
            self.params_metric_layout.addWidget(self.params_group)
            self.params_metric_layout.addWidget(self.metrics_group)
            self.layout.addLayout(self.params_metric_layout)
        else:
            self.layout.addWidget(self.params_group)

        # feature importance
        self.feature_importance_group = QGroupBox()
        self.feature_importance_group.setTitle("Feature importance")
        self.feature_importance_layout = QFormLayout(self.feature_importance_group)

        for feature, value in self.feature_importance.items():
            self.feature_importance_layout.addRow(
                QLabel(feature), QLabel(str(round(value, 2)))
            )

        self.layout.addWidget(self.feature_importance_group)

        # prediction
        self.prediction_group = QGroupBox()
        self.prediction_group.setTitle("Prediction")
        self.prediction_layout = QFormLayout(self.prediction_group)

        self.prediction_input = {}
        self.input_type = {}
        for label, t in configs:
            self.prediction_input[label] = QLineEdit()
            self.input_type[label] = t
            self.prediction_layout.addRow(
                QLabel(f"{label} [{str(t)}]"), self.prediction_input[label]
            )
        self.prediction_button = QPushButton("Predict")
        self.prediction_button.clicked.connect(partial(self.click_listener, "predict"))
        self.prediction_layout.addRow(self.prediction_button)

        # scrollable list collecting every prediction made so far
        self.scroll_box = QGroupBox()
        self.results_layout = QFormLayout(self.scroll_box)
        self.scroll = QScrollArea()
        self.scroll.setWidget(self.scroll_box)
        self.scroll.setWidgetResizable(True)
        self.prediction_layout.addRow(self.scroll)

        self.layout.addWidget(self.prediction_group)

    def click_listener(self, button_type: str):
        """
        Handle a button click; only ``"predict"`` is supported.

        Reads the input fields, converts them to the configured types, calls
        ``self.predict`` with the converted row and appends the entered values
        and the result to the results list. When the conversion fails an
        error dialog is shown and nothing is appended.

        :param button_type: identifier of the clicked button; any value other
            than ``"predict"`` is ignored.
        """
        if button_type != "predict":
            return

        input_dict = {
            label: [field.text()] for label, field in self.prediction_input.items()
        }
        try:
            # astype() returns a new frame; the original code discarded it, so
            # predict() received the raw strings instead of the typed values.
            input_data = pd.DataFrame(input_dict).astype(self.input_type)
        except ValueError:
            error = QMessageBox()
            error.setIcon(QMessageBox.Critical)
            error.setText("Entered data are not in valid types.")
            error.setWindowTitle("Error")
            error.exec_()
            return

        if self.first_prediction:
            self.results_layout.addRow(QLabel("Results:"))
            self.first_prediction = False

        result = self.predict(input_data.iloc[0])
        result_label = QLabel(
            "\n".join(
                [f"{label}: {100 * value:.2f}%" for label, value in result.items()]
            )
        )
        result_label.setAlignment(Qt.AlignRight)
        # the values are echoed exactly as the user typed them
        input_label = QLabel(
            "\n".join([f"{label}: {value}" for label, [value] in input_dict.items()])
        )
        self.results_layout.addRow(input_label, result_label)
self.params_metric_layout.addWidget(self.metrics_group) + self.layout.addLayout(self.params_metric_layout) + else: + self.layout.addWidget(self.params_group) + + # clustering result group + self.clustering_result_group = QGroupBox() + self.clustering_group_layout = QVBoxLayout(self.clustering_result_group) + self.clustering_result_group.setTitle("Clustering result") + + # samples columns choice + self.parameters_widget = SamplesColumnsChoice(self.columns, len(self.data)) + self.parameters_widget.samples_columns_changed.connect(self.update_plot) + self.parameters_widget.samples_columns_changed.connect(self.update_cluster_plot) + self.clustering_group_layout.addWidget(self.parameters_widget) + + # plot + self.results_canvas = ClusteringCanvas(False) + self.clustering_group_layout.addWidget(self.results_canvas, 1) + + self.layout.addWidget(self.clustering_result_group, 1) + + # cluster details + self.clusters_group = QGroupBox() + self.clusters_group_layout = QVBoxLayout(self.clusters_group) + self.clusters_group.setTitle("Clusters") + + self.clusters_table = ClustersTable( + self.data, self.labels, self.mean, len(self.columns) + ) + self.clusters_table.table_changed.connect(self.update_cluster_plot) + self.clusters_group_layout.addWidget(self.clusters_table, 1) + + self.clusters_canvas = ClusteringCanvas(False) + + self.clusters_group_layout.addWidget(self.clusters_canvas, 1) + self.layout.addWidget(self.clusters_group, 1) + + self.update_plot() + self.update_cluster_plot() + + def update_plot(self): + samples_data = self.data.iloc[self.parameters_widget.samples] + x = samples_data[self.parameters_widget.ox] + y = samples_data[self.parameters_widget.oy] + min_x = self.data[self.parameters_widget.ox].min() + max_x = self.data[self.parameters_widget.ox].max() + min_y = self.data[self.parameters_widget.oy].min() + max_y = self.data[self.parameters_widget.oy].max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + + labels = [self.labels[sample] for 
sample in self.parameters_widget.samples] + self.results_canvas.clusters_plot( + x, + y, + list(self.columns), + self.mean.values, + self.sigma, + labels, + self.max_label, + self.parameters_widget.ox, + self.parameters_widget.oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + ) + + def update_cluster_plot(self): + if self.clusters_table.selected_cluster is not None: + indexes = [ + i + for i in range(len(self.labels)) + if self.labels[i] == self.clusters_table.selected_cluster + ] + x = self.data.iloc[indexes][self.parameters_widget.ox] + y = self.data.iloc[indexes][self.parameters_widget.oy] + min_x = x.min() + max_x = x.max() + min_y = y.min() + max_y = y.max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + x_means = self.mean[self.parameters_widget.ox] + y_means = self.mean[self.parameters_widget.oy] + x_index, y_index = [ + self.columns.get_loc(self.parameters_widget.ox), + self.columns.get_loc(self.parameters_widget.oy), + ] + mean = [ + x_means.iloc[self.clusters_table.selected_cluster], + y_means.iloc[self.clusters_table.selected_cluster], + ] + sigma_helper = self.sigma[self.clusters_table.selected_cluster] + sigma = [ + [sigma_helper[x_index][x_index], sigma_helper[x_index][y_index]], + [sigma_helper[y_index][x_index], sigma_helper[y_index][y_index]], + ] + + self.clusters_canvas.chosen_cluster_plot( + x, + y, + mean, + sigma, + self.clusters_table.selected_cluster, + len(x_means), + self.parameters_widget.ox, + self.parameters_widget.oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + ) + else: + x = self.data[self.parameters_widget.ox] + y = self.data[self.parameters_widget.oy] + min_x = x.min() + max_x = x.max() + min_y = y.min() + max_y = y.max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + x_means = self.mean[self.parameters_widget.ox] + y_means = self.mean[self.parameters_widget.oy] + x_index, y_index = [ + self.columns.get_loc(self.parameters_widget.ox), + 
self.columns.get_loc(self.parameters_widget.oy), + ] + means = [x_means, y_means] + sigmas = [] + for i in range(len(self.sigma)): + sigma_helper = self.sigma[i] + sigmas.append( + [ + [ + sigma_helper[x_index][x_index], + sigma_helper[x_index][y_index], + ], + [ + sigma_helper[y_index][x_index], + sigma_helper[y_index][y_index], + ], + ] + ) + + self.clusters_canvas.clusters_means_plot( + means, + sigmas, + self.parameters_widget.ox, + self.parameters_widget.oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + ) diff --git a/src/widgets/results_widgets/k_means_results.py b/src/widgets/results_widgets/k_means_results.py new file mode 100644 index 0000000..3380f2c --- /dev/null +++ b/src/widgets/results_widgets/k_means_results.py @@ -0,0 +1,217 @@ +import pandas as pd +from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QVBoxLayout + +from visualization import ClusteringCanvas +from widgets.components import ClustersTable, ParametersGroupBox, SamplesColumnsChoice +from widgets.results_widgets import AlgorithmResultsWidget + + +class KMeansResultsWidget(AlgorithmResultsWidget): + def __init__(self, data, labels, centroids, options, metrics_info): + super().__init__(data, options, metrics_info) + + self.labels = labels + self.centroids = centroids + + columns = self.data.select_dtypes(include=["number"]).columns + + self.layout = QHBoxLayout(self) + + # algorithm parameters and metrics + self.params_group = ParametersGroupBox(self.options) + if self.metrics_info: + self.metrics_group = ParametersGroupBox(self.metrics_info, "Metrics") + self.params_metric_layout = QVBoxLayout() + self.params_metric_layout.addWidget(self.params_group) + self.params_metric_layout.addWidget(self.metrics_group) + self.layout.addLayout(self.params_metric_layout) + else: + self.layout.addWidget(self.params_group) + + # clustering result group + self.clustering_result_group = QGroupBox() + self.clustering_group_layout = QVBoxLayout(self.clustering_result_group) + 
self.clustering_result_group.setTitle("Clustering result") + + # samples columns choice + self.parameters_widget = SamplesColumnsChoice(columns, len(self.data)) + self.parameters_widget.samples_columns_changed.connect(self.update_plot) + self.parameters_widget.samples_columns_changed.connect(self.update_cluster_plot) + self.clustering_group_layout.addWidget(self.parameters_widget) + + # plot + self.results_canvas = ClusteringCanvas(False) + self.clustering_group_layout.addWidget(self.results_canvas, 1) + + self.layout.addWidget(self.clustering_result_group, 1) + + # centroids + self.clusters_group = QGroupBox() + self.clusters_group_layout = QVBoxLayout(self.clusters_group) + self.clusters_group.setTitle("Clusters") + + self.clusters_table = ClustersTable( + self.data, self.labels, self.centroids, len(columns) + ) + self.clusters_table.table_changed.connect(self.update_cluster_plot) + self.clusters_group_layout.addWidget(self.clusters_table, 1) + + self.clusters_canvas = ClusteringCanvas(False) + + self.clusters_group_layout.addWidget(self.clusters_canvas, 1) + self.layout.addWidget(self.clusters_group, 1) + + self.update_plot() + self.update_cluster_plot() + + def update_plot(self): + samples_data = self.data.iloc[self.parameters_widget.samples] + x = samples_data[self.parameters_widget.ox] + y = samples_data[self.parameters_widget.oy] + min_x = self.data[self.parameters_widget.ox].min() + max_x = self.data[self.parameters_widget.ox].max() + min_y = self.data[self.parameters_widget.oy].min() + max_y = self.data[self.parameters_widget.oy].max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + + labels = [self.labels[sample] for sample in self.parameters_widget.samples] + if self.parameters_widget.ox in self.centroids.columns: + x_centroids = self.centroids[self.parameters_widget.ox] + else: + x_centroids = pd.Series( + [ + self.data.iloc[self.labels == label][ + self.parameters_widget.ox + ].mean() + for label in range(max(self.labels) + 1) + ] + ) 
+ if self.parameters_widget.oy in self.centroids.columns: + y_centroids = self.centroids[self.parameters_widget.oy] + else: + y_centroids = pd.Series( + [ + self.data.iloc[self.labels == label][ + self.parameters_widget.oy + ].mean() + for label in range(max(self.labels) + 1) + ] + ) + + self.results_canvas.all_plot( + x, + y, + x_centroids, + y_centroids, + labels, + self.parameters_widget.ox, + self.parameters_widget.oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + ) + + def update_cluster_plot(self): + if self.clusters_table.selected_cluster is not None: + indexes = [ + i + for i in range(len(self.labels)) + if self.labels[i] == self.clusters_table.selected_cluster + ] + x = self.data.iloc[indexes][self.parameters_widget.ox] + y = self.data.iloc[indexes][self.parameters_widget.oy] + min_x = x.min() + max_x = x.max() + min_y = y.min() + max_y = y.max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + + if self.parameters_widget.ox in self.centroids.columns: + x_centroids = self.centroids[self.parameters_widget.ox] + else: + x_centroids = pd.Series( + [ + self.data.iloc[self.labels == label][ + self.parameters_widget.ox + ].mean() + for label in range(max(self.labels) + 1) + ] + ) + if self.parameters_widget.oy in self.centroids.columns: + y_centroids = self.centroids[self.parameters_widget.oy] + else: + y_centroids = pd.Series( + [ + self.data.iloc[self.labels == label][ + self.parameters_widget.oy + ].mean() + for label in range(max(self.labels) + 1) + ] + ) + + self.clusters_canvas.chosen_centroid_plot( + x, + y, + None, + None, + None, + None, + x_centroids.iloc[self.clusters_table.selected_cluster], + y_centroids.iloc[self.clusters_table.selected_cluster], + self.clusters_table.selected_cluster, + len(x_centroids), + self.parameters_widget.ox, + self.parameters_widget.oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + ) + else: + x = self.data[self.parameters_widget.ox] + y = 
self.data[self.parameters_widget.oy] + min_x = x.min() + max_x = x.max() + min_y = y.min() + max_y = y.max() + sep_x = 0.1 * (max_x - min_x) + sep_y = 0.1 * (max_y - min_y) + + if self.parameters_widget.ox in self.centroids.columns: + x_centroids = self.centroids[self.parameters_widget.ox] + else: + x_centroids = pd.Series( + [ + self.data.iloc[self.labels == label][ + self.parameters_widget.ox + ].mean() + for label in range(max(self.labels) + 1) + ] + ) + if self.parameters_widget.oy in self.centroids.columns: + y_centroids = self.centroids[self.parameters_widget.oy] + else: + y_centroids = pd.Series( + [ + self.data.iloc[self.labels == label][ + self.parameters_widget.oy + ].mean() + for label in range(max(self.labels) + 1) + ] + ) + + self.clusters_canvas.new_centroids_plot( + None, + None, + x_centroids, + y_centroids, + self.parameters_widget.ox, + self.parameters_widget.oy, + min_x - sep_x, + max_x + sep_x, + min_y - sep_y, + max_y + sep_y, + ) diff --git a/src/widgets/rotated_button.py b/src/widgets/rotated_button.py new file mode 100644 index 0000000..c6da55a --- /dev/null +++ b/src/widgets/rotated_button.py @@ -0,0 +1,56 @@ +from PyQt5.QtWidgets import QPushButton, QStyle, QStyleOptionButton, QStylePainter + + +class RotatedButton(QPushButton): + def __init__(self, parent, orientation="east"): + super().__init__(parent) + self.orientation = orientation + + def paintEvent(self, event): + painter = QStylePainter(self) + if self.orientation == "east": + painter.rotate(270) + painter.translate(-1 * self.height(), 0) + if self.orientation == "west": + painter.rotate(90) + painter.translate(0, -1 * self.width()) + painter.drawControl(QStyle.CE_PushButton, self.get_style_options()) + + def minimumSizeHint(self): + size = super(RotatedButton, self).minimumSizeHint() + size.transpose() + return size + + def sizeHint(self): + size = super(RotatedButton, self).sizeHint() + size.transpose() + return size + + def get_style_options(self): + options = 
QStyleOptionButton() + options.initFrom(self) + size = options.rect.size() + size.transpose() + options.rect.setSize(size) + options.features = QStyleOptionButton.None_ + + if self.isFlat(): + options.features |= QStyleOptionButton.Flat + if self.menu(): + options.features |= QStyleOptionButton.HasMenu + if self.autoDefault() or self.isDefault(): + options.features |= QStyleOptionButton.AutoDefaultButton + if self.isDefault(): + options.features |= QStyleOptionButton.DefaultButton + if self.isDown() or (self.menu() and self.menu().isVisible()): + options.state |= QStyle.State_Sunken + if self.isChecked(): + options.state |= QStyle.State_On + if not self.isFlat() and not self.isDown(): + options.state |= QStyle.State_Raised + + options.text = self.text() + options.icon = self.icon() + options.iconSize = self.iconSize() + + return options diff --git a/src/widgets/steps_widgets/__init__.py b/src/widgets/steps_widgets/__init__.py new file mode 100644 index 0000000..a577c1e --- /dev/null +++ b/src/widgets/steps_widgets/__init__.py @@ -0,0 +1,5 @@ +from .algorithm_vis import AlgorithmStepsVisualization +from .a_priori_vis import APrioriStepsVisualization +from .extra_trees_vis import ExtraTreesStepsVisualization +from .gmm_vis import GMMStepsVisualization +from .k_means_vis import KMeansStepsVisualization diff --git a/src/widgets/steps_widgets/a_priori_vis.py b/src/widgets/steps_widgets/a_priori_vis.py new file mode 100644 index 0000000..cc9f32f --- /dev/null +++ b/src/widgets/steps_widgets/a_priori_vis.py @@ -0,0 +1,350 @@ +from functools import partial +from typing import List + +import pandas as pd +from PyQt5.QtCore import Qt +from PyQt5.QtWidgets import ( + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QPushButton, + QScrollArea, + QSizePolicy, + QSpinBox, + QTableView, + QVBoxLayout, +) + +from algorithms.associations import APrioriPartLabel +from utils import AutomateSteps, format_set +from visualization import APrioriGauge, APrioriGraphPlot +from widgets 
import QtTable +from widgets.steps_widgets import AlgorithmStepsVisualization + + +class APrioriStepsVisualization(AlgorithmStepsVisualization): + def __init__( + self, data: pd.DataFrame, algorithms_steps: List[dict], is_animation: bool + ): + super().__init__(data, algorithms_steps, is_animation) + + self.is_running = False + self.animation = None + self.max_step = len(algorithms_steps) - 1 + self.current_step = 0 + + self.setObjectName("a_priori_steps_visualization") + + # layout + self.layout = QVBoxLayout(self) + self.bottom_row_layout = QHBoxLayout() + self.bottom_row_layout.setSpacing(35) + + # visualization layout + self.visualization_box = QGroupBox() + self.visualization_box.setTitle("APriori step by step") + self.visualization_box_layout = QVBoxLayout(self.visualization_box) + + # visualization charts and plots + self.sets_table = QTableView() + self.gauge_chart = APrioriGauge() + self.gauge_chart.layout().setContentsMargins(0, 0, 0, 0) + self.graph_plot = APrioriGraphPlot() + self.graph_plot.layout().setContentsMargins(0, 0, 0, 0) + self.algorithm_part_label = QLabel() + + self.step_vis_layout = QHBoxLayout() + self.step_charts_layout = QVBoxLayout() + + self.step_charts_layout.addWidget(self.graph_plot, 1) + self.step_charts_layout.addWidget(self.gauge_chart, 1) + + self.step_vis_layout.addWidget(self.sets_table, 1) + self.step_vis_layout.addLayout(self.step_charts_layout, 1) + + self.visualization_box_layout.addWidget( + self.algorithm_part_label, 0, alignment=Qt.AlignCenter + ) + self.visualization_box_layout.addLayout(self.step_vis_layout, 3) + self.visualization_box_layout.addWidget(self._render_description(), 2) + + # controls + self._render_control_ui() + + self.layout.addWidget(self.visualization_box) + self.update_plot(0) + + def _render_control_ui(self): + if self.is_animation: + self.automat = AutomateSteps( + lambda: self.change_step(1), + lambda: self.change_step(-1 * self.current_step), + ) + self.is_running = False + + # animation + 
self.animation_box = QGroupBox() + self.animation_box.setFixedWidth(250) + self.animation_box.setTitle("Animation") + self.animation_box_layout = QFormLayout(self.animation_box) + + self.restart_button = QPushButton("Restart") + self.restart_button.clicked.connect(partial(self.click_listener, "restart")) + self.run_button = QPushButton("Start animation") + self.run_button.clicked.connect(partial(self.click_listener, "run")) + self.interval_box = QSpinBox() + self.interval_box.setMinimum(500) + self.interval_box.setMaximum(3000) + self.interval_box.setValue(1000) + self.interval_box.setSingleStep(20) + + self.animation_box_layout.addRow( + QLabel("Interval time [ms]:"), self.interval_box + ) + self.animation_box_layout.addRow(self.restart_button) + self.animation_box_layout.addRow(self.run_button) + + self.step_label = QLabel("STEP: {}".format(self.current_step)) + self.animation_box_layout.addWidget(self.step_label) + + self.description_group_box_layout.addWidget(self.animation_box, 0) + + else: + self.visualization_box_layout.addStretch() + + # control buttons + self.control_buttons_layout = QHBoxLayout() + self.left_box = QSpinBox() + self.left_box.setMinimum(1) + self.right_box = QSpinBox() + self.right_box.setMinimum(1) + self.left_button = QPushButton("PREV") + self.left_button.clicked.connect(partial(self.click_listener, "prev")) + self.right_button = QPushButton("NEXT") + self.right_button.clicked.connect(partial(self.click_listener, "next")) + self.step_label = QLabel("STEP: {}".format(self.current_step)) + + self.next_part_button = QPushButton("NEXT PART") + self.next_part_button.clicked.connect( + partial(self.click_listener, "next part") + ) + self.prev_part_button = QPushButton("PREV PART") + self.prev_part_button.clicked.connect( + partial(self.click_listener, "prev part") + ) + + self.control_buttons_layout.addWidget(self.prev_part_button) + self.control_buttons_layout.addWidget(self.left_button) + self.control_buttons_layout.addWidget(self.left_box) 
+ self.control_buttons_layout.addStretch() + self.control_buttons_layout.addWidget(self.step_label) + self.control_buttons_layout.addStretch() + self.control_buttons_layout.addWidget(self.right_box) + self.control_buttons_layout.addWidget(self.right_button) + self.control_buttons_layout.addWidget(self.next_part_button) + + self.visualization_box_layout.addLayout(self.control_buttons_layout, 0) + + def _render_description(self): + description = "Apriori algorithm - steps visualization" + + self.description_label = QLabel(description) + self.description_label.setWordWrap(True) + self.description_label.setSizePolicy( + QSizePolicy.Expanding, QSizePolicy.Expanding + ) + + self.description_group_box = QGroupBox() + self.description_group_box.setTitle("Description") + self.description_group_box_layout = QHBoxLayout(self.description_group_box) + + self.scroll_box = QGroupBox() + self.scroll_box_layout = QFormLayout(self.scroll_box) + self.scroll = QScrollArea() + self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.scroll.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.scroll.setWidget(self.scroll_box) + self.scroll.setWidgetResizable(True) + self.scroll.setMinimumHeight(26) + self.description_group_box_layout.addWidget(self.scroll) + + self.scroll_box_layout.addWidget(self.description_label) + return self.description_group_box + + def click_listener(self, button_type: str): + match button_type: + case "prev": + self.change_step(-self.left_box.value()) + case "next": + self.change_step(self.right_box.value()) + case "restart": + self.is_running = False + self.interval_box.setEnabled(True) + self.run_button.setEnabled(True) + self.automat.restart() + self.run_button.setText("Start animation") + case "run": + self.is_running = not self.is_running + if self.is_running: + self.restart_button.setEnabled(False) + self.interval_box.setEnabled(False) + 
self.automat.set_time(self.interval_box.value()) + self.automat.resume() + self.run_button.setText("Stop animation") + else: + self.automat.pause() + self.run_button.setText("Start animation") + self.restart_button.setEnabled(True) + case "next part": + step_delta = 1 + current_part = self.algorithms_steps[self.current_step]["part"] + while ( + self.current_step + step_delta <= self.max_step + and self.algorithms_steps[self.current_step + step_delta]["part"] + == current_part + ): + step_delta += 1 + + self.change_step(step_delta) + case "prev part": + step_delta = -1 + current_part = self.algorithms_steps[self.current_step]["part"] + while ( + self.current_step + step_delta >= 0 + and self.algorithms_steps[self.current_step + step_delta]["part"] + == current_part + ): + step_delta -= 1 + + current_part = self.algorithms_steps[self.current_step + step_delta][ + "part" + ] + while ( + self.current_step + step_delta >= 0 + and self.algorithms_steps[self.current_step + step_delta]["part"] + == current_part + ): + step_delta -= 1 + + step_delta += 1 + self.change_step(step_delta) + + def change_step(self, change: int): + new_step = max(0, min(self.max_step, self.current_step + change)) + if new_step == self.current_step: + return + self.current_step = new_step + self.step_label.setText("STEP: {}".format(self.current_step)) + self.update_plot(self.current_step) + self.step_label.update() + + def update_plot(self, step: int): + if step == self.max_step: + if self.is_animation: + self.run_button.setText("Start animation") + self.run_button.setEnabled(False) + self.is_running = False + self.change_enabled_buttons(True) + + self.current_step = step + step_dict = self.algorithms_steps[self.current_step] + df = step_dict["data_frame"] + if df is not None: + df = df.reset_index() + if "confidence" in df: + df.rename(columns={"index": "association rules"}, inplace=True) + else: + df.rename(columns={"index": "frequent sets"}, inplace=True) + + 
self.sets_table.setModel(QtTable(df) if df is not None else None) + self.algorithm_part_label.setText(step_dict["part"].value) + self.gauge_chart.reset() + self.graph_plot.reset() + + description = "" + match step_dict["part"]: + case APrioriPartLabel.CALCULATE_SUPPORT: + description = ( + "Checking whether set: {} is frequent.\n" + "Its support equals {} \nIt is {}a frequent set.".format( + format_set(step_dict["set"]), + round(step_dict["support"], 3), + "not " + if step_dict["support"] < step_dict["min_support"] + else "", + ) + ) + self.gauge_chart.plot_value( + round(step_dict["support"], 3), step_dict["min_support"], "support" + ) + self.graph_plot.plot_set(step_dict["set"]) + case APrioriPartLabel.FILTER_BY_SUPPORT: + description = ( + "We have found that the following sets " + "are frequent:\n{}, whereas those are not:\n{}".format( + "\n".join(map(format_set, step_dict["frequent_sets"])), + "\n".join(map(format_set, step_dict["infrequent_sets"])), + ) + ) + case APrioriPartLabel.SAVE_K_SETS: + description = "We have found all frequent sets for k={}".format( + step_dict["k"] + ) + case APrioriPartLabel.SAVE_RULES: + description = ( + "We have found all association rules for " + "specified minimum confidence and support." + ) + case APrioriPartLabel.JOIN_AND_PRUNE: + description = ( + "We are joining sets: {} and {}, " + "then analyzing resulting set: {}. ".format( + format_set(step_dict["set_1"]), + format_set(step_dict["set_2"]), + format_set(step_dict["new_set"]), + ) + ) + + if step_dict["infrequent_subset"] is None: + description += ( + "This set does not contain any infrequent " + "subsets. It might be frequent itself." + ) + else: + description += ( + "This set contains an infrequent subset: " + "{}. Therefore it is not frequent itself.".format( + step_dict["infrequent_subset"] + ) + ) + self.graph_plot.plot_set(step_dict["new_set"]) + + case APrioriPartLabel.GENERATE_RULES: + description = ( + "We divide frequent set into A = {} and B = {}. 
" + "The confidence of the rule A => B " + "equals {}\n\n".format( + format_set(step_dict["set_a"]), + format_set(step_dict["set_b"]), + round(step_dict["confidence"], 3), + ) + ) + + if step_dict["confidence"] >= step_dict["min_confidence"]: + description += "We have found a new association rule." + else: + description += ( + "It is not enough to consider it a " + "valid association rule for our data." + ) + + self.graph_plot.plot_rule(step_dict["set_a"], step_dict["set_b"]) + self.gauge_chart.plot_value( + round(step_dict["confidence"], 3), + step_dict["min_confidence"], + "confidence", + ) + + self.description_label.setText(description) diff --git a/src/widgets/steps_widgets/algorithm_vis.py b/src/widgets/steps_widgets/algorithm_vis.py new file mode 100644 index 0000000..24c7248 --- /dev/null +++ b/src/widgets/steps_widgets/algorithm_vis.py @@ -0,0 +1,21 @@ +from typing import List + +import pandas as pd +from PyQt5.QtWidgets import QWidget + + +class AlgorithmStepsVisualization(QWidget): + """ + Widget with visualization of algorithm creation + It is shown in 'Algorithm Run' section + Works in two mode: with animation and without animation + """ + + def __init__(self, data: pd.DataFrame, algorithms_steps: List, is_animation: bool): + super().__init__() + + self.data = data + self.algorithms_steps = algorithms_steps + self.is_animation = is_animation + + self.layout = None diff --git a/src/widgets/steps_widgets/extra_trees_vis.py b/src/widgets/steps_widgets/extra_trees_vis.py new file mode 100644 index 0000000..c6d0e21 --- /dev/null +++ b/src/widgets/steps_widgets/extra_trees_vis.py @@ -0,0 +1,367 @@ +from functools import partial +from random import randint +from typing import Dict, List, Optional, Tuple + +import graphviz +import pandas as pd +import pygraphviz as pgv +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QFont, QPixmap +from PyQt5.QtWidgets import ( + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QPushButton, + QSpinBox, + QTableView, + 
QVBoxLayout, + QWidget, +) +from QGraphViz.DotParser import Graph, GraphType +from QGraphViz.Engines import Dot +from QGraphViz.QGraphViz import QGraphViz + +from utils import AutomateSteps, QImage, QtImageViewer +from widgets import QtTable +from widgets.steps_widgets import AlgorithmStepsVisualization + + +class StepWidget(QWidget): + def __init__(self, graph: QWidget, info: Optional[List] = None): + super().__init__() + self.graph = graph + self.info = info + self.layout = QHBoxLayout(self) + + self.info_group = QGroupBox("Description") + self.info_group.setFixedWidth(300) + self.info_group_layout = QVBoxLayout(self.info_group) + if self.info: + self.table_info = QTableView() + data = pd.DataFrame( + self.info, columns=["Column name", "Pivot", "Metric changes"] + ) + data.sort_values(by=["Metric changes"], ascending=False, inplace=True) + self.table_info.setModel(QtTable(data)) + description = ( + "Algorithm is looking for new pivot for dark blue node. " + "It draws n features from columns and for each draws a split point (n is provided in options section). " + "In the next step the metric is calculated for each split. " + "The splits are listed in table and sorted by metrics (nan in value of metrics means that this pivot is forbidden). " + "The best split is chosen." + ) + else: + description = ( + "There are result of division on visualization. " + "Under the orange edge algorithm considers samples, which not fulfills condition in parent node. " + "Under blue edge samples fulfills this condition. " + "The oval nodes are leaves. The rectangular nodes will be divided in next steps." 
+ ) + self.description_label = QLabel(description) + self.description_label.setWordWrap(True) + self.info_group_layout.addWidget(self.description_label) + if self.info: + self.info_group_layout.addWidget(self.table_info) + self.info_group_layout.addStretch(1) + self.layout.addWidget(self.info_group) + self.layout.addWidget(self.graph) + + +class TreeStepsVisualization(QWidget): + def __init__( + self, widget: QWidget, node_info: Dict, dot_steps: List[str], is_animation: bool + ): + super().__init__() + self.setWindowTitle("Tree creation steps") + self.parent = widget + self.node_info = node_info + self.dot_steps = dot_steps + self.max_steps = len(self.dot_steps) + self.is_animation = is_animation + self.current_step = 0 + + self.layout = QVBoxLayout(self) + + self.step_group = QGroupBox() + self.step_group_layout = QVBoxLayout(self.step_group) + self.step_group_layout.addWidget(self.create_step_graph(self.current_step)) + self.layout.addWidget(self.step_group, 1) + + if self.is_animation: + self.automat = AutomateSteps( + lambda: self.change_step(1), + lambda: self.change_step(-1 * self.current_step), + ) + self.is_running = False + + # animation + self.animation_box = QGroupBox() + self.animation_box.setFixedWidth(250) + self.animation_box.setTitle("Animation") + self.animation_box_layout = QFormLayout(self.animation_box) + + self.restart_button = QPushButton("Restart") + self.restart_button.clicked.connect(partial(self.click_listener, "restart")) + self.run_button = QPushButton("Start animation") + self.run_button.clicked.connect(partial(self.click_listener, "run")) + self.interval_box = QSpinBox() + self.interval_box.setMinimum(500) + self.interval_box.setMaximum(3000) + self.interval_box.setValue(1000) + self.interval_box.setSingleStep(20) + + self.animation_box_layout.addRow( + QLabel("Interval time [ms]:"), self.interval_box + ) + self.animation_box_layout.addRow(self.restart_button) + self.animation_box_layout.addRow(self.run_button) + + 
self.layout.addWidget(self.animation_box, 0) + + if not self.is_animation: + # control buttons + self.control_buttons_layout = QHBoxLayout() + self.left_box = QSpinBox() + self.left_box.setMinimum(1) + self.right_box = QSpinBox() + self.right_box.setMinimum(1) + self.left_button = QPushButton("PREV") + self.left_button.clicked.connect(partial(self.click_listener, "prev")) + self.right_button = QPushButton("NEXT") + self.right_button.clicked.connect(partial(self.click_listener, "next")) + self.step_label = QLabel("STEP: {}".format(self.current_step)) + self.control_buttons_layout.addWidget(self.left_button) + self.control_buttons_layout.addWidget(self.left_box) + self.control_buttons_layout.addStretch() + self.control_buttons_layout.addWidget(self.step_label) + self.control_buttons_layout.addStretch() + self.control_buttons_layout.addWidget(self.right_box) + self.control_buttons_layout.addWidget(self.right_button) + + self.layout.addLayout(self.control_buttons_layout, 0) + else: + self.step_label = QLabel("STEP: {}".format(self.current_step)) + self.layout.addWidget(self.step_label, 0, alignment=Qt.AlignCenter) + + def create_step_graph(self, step_num: int): + if step_num % 2 == 0: + info = self.node_info[step_num] + else: + info = None + graph = graphviz.Source(self.dot_steps[step_num]) + graph.render("tmp/graph", format="png") + if self.is_animation: + image = QImage() + image.setPixmap(QPixmap("tmp/graph.png")) + else: + image = QtImageViewer() + image.open("tmp/graph.png") + return StepWidget(image, info) + + def update_step(self): + for i in reversed(range(self.step_group_layout.count())): + self.step_group_layout.itemAt(i).widget().setParent(None) + self.step_group_layout.addWidget(self.create_step_graph(self.current_step)) + + self.step_label.setText("STEP: {}".format(self.current_step)) + self.step_label.update() + + def click_listener(self, button_type: str): + match button_type: + case "prev": + num = self.left_box.value() + self.change_step(-1 * num) + 
case "next": + num = self.right_box.value() + self.change_step(num) + case "restart": + self.is_running = False + self.interval_box.setEnabled(True) + self.run_button.setEnabled(True) + self.automat.restart() + self.run_button.setText("Start animation") + case "run": + self.is_running = not self.is_running + if self.is_running: + self.restart_button.setEnabled(False) + self.interval_box.setEnabled(False) + self.automat.set_time(self.interval_box.value()) + self.automat.resume() + self.run_button.setText("Stop animation") + else: + self.automat.pause() + self.run_button.setText("Start animation") + self.restart_button.setEnabled(True) + + def change_step(self, delta: int): + new_step = delta + self.current_step + new_step = max(0, min(new_step, self.max_steps - 1)) + if new_step == self.current_step: + return + self.current_step = new_step + self.update_step() + if self.current_step == self.max_steps - 1 and self.is_animation: + self.automat.pause() + self.run_button.setText("Start animation") + self.run_button.setEnabled(False) + self.restart_button.setEnabled(True) + + +class ExtraTreesStepsVisualization(AlgorithmStepsVisualization): + def __init__( + self, + data: pd.DataFrame, + algorithms_steps: List[Tuple[str, Dict, List[str]]], + is_animation: bool, + ): + super().__init__(data, algorithms_steps, is_animation) + + self.steps_window = None + self.layout = QVBoxLayout(self) + self.graphs = [ + self.make_graph(dot_string) for dot_string, _, _ in self.algorithms_steps + ] + self.current_graph = 1 + + # graph section + self.graph_group = QGroupBox() + self.graph_group.setTitle("Graph") + self.graph_group_layout = QVBoxLayout(self.graph_group) + self.graph_group_layout.addWidget(self.graphs[self.current_graph - 1]) + self.layout.addWidget(self.graph_group, 1) + + # control panel + self.control_panel_layout = QHBoxLayout() + self.left_button = QPushButton("PREV") + self.left_button.clicked.connect(partial(self.click_listener, "prev")) + self.right_button = 
QPushButton("NEXT") + self.right_button.clicked.connect(partial(self.click_listener, "next")) + self.num_label = QLabel(f"Tree {self.current_graph}") + self.description = QLabel( + "Visualization of tree from the random forest.\nBlue color of edge means fulfillment condition in node.\nOrange edge leads to the subtree for unfulfilled condition.\nOval nodes are leaves. Each color is related with dominant class." + ) + self.random_button = QPushButton("Random graph") + self.random_button.clicked.connect(partial(self.click_listener, "random")) + self.steps_button = QPushButton("Creation steps") + self.steps_button.clicked.connect(partial(self.click_listener, "steps")) + self.control_panel_layout.addWidget(self.description) + self.control_panel_layout.addStretch() + self.control_panel_layout.addWidget(self.num_label) + self.control_panel_layout.addStretch() + self.control_panel_layout.addWidget(self.steps_button) + self.control_panel_layout.addWidget(self.random_button) + self.control_panel_layout.addWidget(self.left_button) + self.control_panel_layout.addWidget(self.right_button) + self.layout.addLayout(self.control_panel_layout, 0) + + @staticmethod + def postprocess_label(label: Optional[str]): + if label is None: + return "" + label = label[1:-1] + label = label.replace(">", ">") + if "
" not in label: + return label + return "\n".join(label.split("
")) + + def dict_of_param(self, data: Optional[str]) -> Dict: + if data is None: + return {} + res = {} + params = data.split(", ") + for param in params: + key, value = param.split("=", 1) + if key == "fillcolor": + value = value[1:-1] + if key == "label": + value = self.postprocess_label(value) + res[key] = value + return res + + def make_graph(self, data: str): + graph_info = pgv.AGraph(data) + graph_info.layout(prog="dot") + poses = [ + float(node.attr["pos"].split(",")[1]) for node in graph_info.nodes_iter() + ] + pos_delta = min(poses) + max(poses) + 20 + lines = data.split("{", 1)[1].rsplit("}", 1)[0].split("\n") + qgv = QGraphViz(auto_freeze=True, hilight_Nodes=True) + qgv.setStyleSheet("background-color:white;") + qgv.new( + Dot( + Graph("graph", graph_type=GraphType.DirectedGraph), + font=QFont("Helvetica", 12), + margins=[20, 20], + ) + ) + nodes = {} + for line in lines: + if not line: + continue + if line.startswith("node") or line.startswith("edge"): + continue + param = None + if "[" in line: + value, param = line.split("[", 1) + param = param.rsplit("]", 1)[0] + else: + value = line.rsplit(";", 1)[0] + value_list = value.split() + if len(value_list) == 1: + node = value_list[0] + param_dict = self.dict_of_param(param) + if "label" in param_dict.keys(): + label_lines = param_dict["label"].split("\n") + max_len = max(label_lines, key=len) + rect = qgv.engine.fm.boundingRect(max_len) + width = rect.width() + 20 + height = rect.height() * len(label_lines) + 20 + param_dict["size"] = (width, height) + param_dict["pos"] = [ + float(value) + for value in graph_info.get_node(node).attr["pos"].split(",") + ] + param_dict["pos"][0] = int(param_dict["pos"][0]) + param_dict["pos"][1] = int(pos_delta - param_dict["pos"][1]) + nodes[node] = qgv.addNode(qgv.engine.graph, str(node), **param_dict) + elif len(value_list) == 3: + node1 = value_list[0] + node2 = value_list[2] + param_dict = self.dict_of_param(param) + qgv.addEdge(nodes[node1], nodes[node2], param_dict) 
def click_listener(self, button_type: str):
    """React to one of the control-panel buttons.

    Args:
        button_type: ``"steps"`` opens a ``TreeStepsVisualization`` window
            for the currently shown tree; ``"next"`` / ``"prev"`` move to
            the neighbouring tree (no-op at either end); ``"random"`` jumps
            to a random tree. Any other value only refreshes the graph.
    """
    if button_type == "steps":
        # current_graph is 1-based, algorithms_steps is 0-based.
        data = self.algorithms_steps[self.current_graph - 1]
        self.steps_window = TreeStepsVisualization(
            self, data[1], data[2], self.is_animation
        )
        self.steps_window.show()
        return

    if button_type in ("next", "prev"):
        if button_type == "next":
            candidate = min(len(self.graphs), self.current_graph + 1)
        else:
            candidate = max(1, self.current_graph - 1)
        if candidate == self.current_graph:
            # Already at the first/last tree: nothing to redraw.
            return
        self.current_graph = candidate
    elif button_type == "random":
        self.current_graph = randint(1, len(self.graphs))

    self.update_graph()
def update_plot(self, step: int = -1):
    """Draw one step of the GMM visualization on the canvas.

    Args:
        step: Step to draw. ``-1`` (the default) redraws the step stored in
            ``self.clustering_template.current_step``; any other value is
            written back to the template and its step label is refreshed.

    Returns:
        Whatever ``self.canvas.data_plot`` (step 0) or
        ``self.canvas.clusters_plot`` (any other step) returns.
    """
    template = self.clustering_template
    if step == self.max_step:
        template.end_animation()
    if step != -1:
        template.current_step = step
        template.update_step_label()
    else:
        step = template.current_step

    samples = template.samples
    ox, oy = template.ox, template.oy
    redraw = not template.is_running

    chosen = self.data.iloc[samples]
    x, y = chosen[ox], chosen[oy]

    # Axis limits come from the whole data set (not just the sampled rows)
    # and are widened by 10% of the value range on each side.
    x_low, x_high = self.data[ox].min(), self.data[ox].max()
    y_low, y_high = self.data[oy].min(), self.data[oy].max()
    pad_x = 0.1 * (x_high - x_low)
    pad_y = 0.1 * (y_high - y_low)
    view = (
        ox,
        oy,
        x_low - pad_x,
        x_high + pad_x,
        y_low - pad_y,
        y_high + pad_y,
        redraw,
    )

    if step == 0:
        # Step 0 shows the raw points without any clustering.
        return self.canvas.data_plot(x, y, *view)

    # NOTE(review): step k draws algorithms_steps[k - 1]; if max_step equals
    # len(algorithms_steps) - 1 the last recorded entry is never drawn -
    # confirm that this is intended.
    step_labels, mean, sigma = self.algorithms_steps[step - 1]
    labels = [step_labels[sample] for sample in samples]
    return self.canvas.clusters_plot(
        x,
        y,
        list(self.columns),
        mean,
        sigma,
        labels,
        self.num_cluster,
        *view,
    )
def update_plot(self, step: int = -1):
    """Draw one step of the K-Means visualization on the canvas.

    Step layout: 0 is the raw data, 1 the first centroids, 2 the first
    assignment. Every later recorded iteration takes ``num_cluster + 2``
    consecutive steps: one per cluster (``chosen_centroid_plot``), one
    showing old and new centroids (``new_centroids_plot``) and one showing
    the resulting assignment (``all_plot``).

    Args:
        step: Step to draw. ``-1`` (the default) redraws the step stored in
            ``self.clustering_template.current_step``; any other value is
            written back to the template and its step label is refreshed.

    Returns:
        The return value of the canvas method used for the step.
    """
    template = self.clustering_template
    if step == self.max_step:
        template.end_animation()
    if step != -1:
        template.current_step = step
        template.update_step_label()
    else:
        step = template.current_step

    samples = template.samples
    ox, oy = template.ox, template.oy
    redraw = not template.is_running

    chosen = self.data.iloc[samples]
    x, y = chosen[ox], chosen[oy]

    # Axis limits come from the whole data set and are widened by 10% of
    # the value range on each side.
    x_low, x_high = self.data[ox].min(), self.data[ox].max()
    y_low, y_high = self.data[oy].min(), self.data[oy].max()
    pad_x = 0.1 * (x_high - x_low)
    pad_y = 0.1 * (y_high - y_low)
    view = (
        ox,
        oy,
        x_low - pad_x,
        x_high + pad_x,
        y_low - pad_y,
        y_high + pad_y,
        redraw,
    )

    if step == 0:
        return self.canvas.data_plot(x, y, *view)

    first_labels, first_centroids = self.algorithms_steps[0]
    labels = [first_labels[sample] for sample in samples]
    x_centroids, y_centroids = first_centroids[ox], first_centroids[oy]

    if step == 1:
        # No previous centroids exist yet, hence the two None arguments.
        return self.canvas.new_centroids_plot(
            None, None, x_centroids, y_centroids, *view
        )
    if step == 2:
        return self.canvas.all_plot(x, y, x_centroids, y_centroids, labels, *view)

    # Steps from 3 on are grouped in cycles of num_cluster + 2 steps.
    iteration, mode = divmod(step - 3, self.num_cluster + 2)
    iteration += 1

    new_labels, new_centroids = self.algorithms_steps[iteration]
    labels = [new_labels[sample] for sample in samples]
    x_centroids, y_centroids = new_centroids[ox], new_centroids[oy]

    previous_labels, previous_centroids = self.algorithms_steps[iteration - 1]
    old_x_centroids = previous_centroids[ox]
    old_y_centroids = previous_centroids[oy]

    if mode == self.num_cluster:
        return self.canvas.new_centroids_plot(
            old_x_centroids, old_y_centroids, x_centroids, y_centroids, *view
        )
    if mode > self.num_cluster:
        return self.canvas.all_plot(x, y, x_centroids, y_centroids, labels, *view)

    # mode is a cluster number: highlight the points that belonged to it in
    # the previous iteration and show its centroid moving.
    old_labels = np.array([previous_labels[sample] for sample in samples])
    members = old_labels == mode
    others = old_labels != mode
    return self.canvas.chosen_centroid_plot(
        x.loc[members],
        y.loc[members],
        x.loc[others],
        y.loc[others],
        old_x_centroids.iloc[mode],
        old_y_centroids.iloc[mode],
        x_centroids.iloc[mode],
        y_centroids.iloc[mode],
        mode,
        len(x_centroids),
        *view,
    )
class QtTable(QAbstractTableModel):
    """Read-only Qt table model over a tabular data object.

    The wrapped object must provide ``shape``, ``iloc``, ``columns`` and
    ``index`` (elsewhere in this code base it is a pandas ``DataFrame``).
    Every cell and every header label is exposed to the view as text.
    """

    def __init__(self, data):
        """Store ``data`` as the backing table; it is not copied."""
        QAbstractTableModel.__init__(self)
        self._data = data

    def rowCount(self, parent=None):
        """Return the number of rows of the backing table."""
        return self._data.shape[0]

    def columnCount(self, parent=None):
        """Return the number of columns of the backing table."""
        return self._data.shape[1]

    def data(self, index, role=Qt.DisplayRole):
        """Return the cell at ``index`` as a string for the display role.

        Returns ``None`` for invalid indexes and for every other role.
        """
        if index.isValid() and role == Qt.DisplayRole:
            return str(self._data.iloc[index.row(), index.column()])
        return None

    @QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str)
    def headerData(
        self,
        section: int,
        orientation: QtCore.Qt.Orientation,
        role: int = QtCore.Qt.DisplayRole,
    ):
        """Return the header label of ``section`` as a string.

        Horizontal headers come from the column labels and vertical headers
        from the index labels. Both are converted with ``str`` so that
        non-string labels (for example integer column names) are handed to
        Qt in the same form as the vertical ones. Roles other than the
        display role yield an empty ``QVariant``.
        """
        if role != QtCore.Qt.DisplayRole:
            return QtCore.QVariant()
        if orientation == QtCore.Qt.Horizontal:
            labels = self._data.columns
        else:
            labels = self._data.index
        return str(labels[section])
def estimation_header_click(self, index):
    """Ask for a default value and fill the missing cells of one column.

    Args:
        index: Position of the column whose header was double-clicked.

    The text entered in the dialog is converted with ``cast_input_type``
    according to the column dtype. Nothing is changed when the dialog is
    cancelled or the conversion fails (``cast_input_type`` returns ``None``).
    """
    new_value, ok = QInputDialog.getText(
        self, f"Default value for column {index}:", "Value:", QLineEdit.Normal, ""
    )
    if not ok:
        return

    data = self.engine.state.imported_data
    column = data.iloc[:, index]
    new_value = self.cast_input_type(column.dtype, new_value)
    if new_value is None:
        return

    # Write the filled column back explicitly. Calling
    # fillna(..., inplace=True) on the iloc selection only modifies that
    # selection, which pandas does not guarantee to share memory with the
    # frame (and never does under Copy-on-Write).
    data.iloc[:, index] = column.fillna(new_value)
    self.data_table.setModel(QtTable(data))
self.data_table.selectionModel().selectedIndexes()[0] + row, col = cell.row(), cell.column() + new_value, ok = QInputDialog.getText( + self, + f"Default value for cell ({row}, {col}):", + "Value:", + QLineEdit.Normal, + "", + ) + if ok: + column = self.engine.state.imported_data.iloc[:, col] + new_value = self.cast_input_type(column.dtype, new_value) + if new_value is not None: + self.engine.state.imported_data.iloc[row, col] = new_value + self.data_table.setModel(QtTable(self.engine.state.imported_data)) + + @staticmethod + def cast_input_type(data_type, value): + try: + match data_type: + case "int32" | "int64": + return int(value) + case "float64": + return float(value) + case "datetime64[ns]": + return strptime(value) + case _: + return value + except ValueError: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText("Data is not valid") + error.setWindowTitle("Error") + error.exec_() + + def closeEvent(self, event): + self.parent.get_data() diff --git a/src/widgets/tables/merging_sets_screen.py b/src/widgets/tables/merging_sets_screen.py new file mode 100644 index 0000000..a9b7318 --- /dev/null +++ b/src/widgets/tables/merging_sets_screen.py @@ -0,0 +1,394 @@ +from functools import partial + +import numpy as np +import pandas as pd +from PyQt5.QtCore import QMimeData, Qt +from PyQt5.QtGui import QDrag, QPixmap +from PyQt5.QtWidgets import ( + QBoxLayout, + QComboBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + QPushButton, + QStyle, + QTableView, + QVBoxLayout, + QWidget, +) + +from data_import import Loader +from widgets import LoadingWidget, QtTable + + +class DragButton(QPushButton): + def mouseMoveEvent(self, e): + if e.buttons() == Qt.LeftButton: + drag = QDrag(self) + mime = QMimeData() + drag.setMimeData(mime) + pixmap = QPixmap(self.size()) + self.render(pixmap) + drag.setPixmap(pixmap) + drag.exec_(Qt.MoveAction) + + +class ColumnWidget(QWidget): + def __init__(self, 
def _on_remove(self):
    """Detach the row that shows this column's name from the parent layout.

    The layout is searched from the end towards the start and only the
    first match is detached. The final layout item is never inspected
    (presumably the trailing stretch; verify against the layout setup).
    """
    layout = self.parent_layout
    position = layout.count() - 2
    while position >= 0:
        if layout.itemAt(position).widget().drag_button.text() == self.column_name:
            # Re-parenting to None removes the widget from the layout.
            layout.itemAt(position).widget().setParent(None)
            return
        position -= 1
self.columns_layout = QHBoxLayout() + self.left_columns = QWidget() + self.right_columns = QWidget() + self.columns_left_layout = QVBoxLayout() + self.columns_right_layout = QVBoxLayout() + self.submit_button = QPushButton() + + # layouts stretch settings + self.columns_left_layout.addStretch(QBoxLayout.BottomToTop) + self.columns_right_layout.addStretch(QBoxLayout.BottomToTop) + + # rendering content + self._render_table(self.current_data_view, True) + self._render_panel() + self._render_table(self.new_data_view, False) + + self.setLayout(self.layout) + + def _load_styles(self): + with open("../static/css/styles.css") as stylesheet: + self.setStyleSheet(stylesheet.read()) + + def _render_table(self, table_widget, current): + table_group = QGroupBox(self) + table_group_layout = QVBoxLayout(table_group) + table_group.setTitle("Original table" if current else "Table to merge") + if current: + table_widget.setModel(QtTable(self.engine.state.imported_data)) + else: + data = self.new_data if self.new_data is not None else pd.DataFrame() + table_widget.setModel(QtTable(data)) + table_group_layout.addWidget(table_widget) + self.layout.addWidget(table_group, 1) + + def _render_panel(self): + panel = QWidget() + panel_layout = QVBoxLayout(panel) + self._render_import_view() + self._render_columns_view() + panel_layout.addWidget(self.load_data_group) + panel_layout.addWidget(self.columns_merge_group, 1) + self.layout.addWidget(panel, 1) + + def _render_import_view(self): + self.load_data_group_layout.setFieldGrowthPolicy( + QFormLayout.AllNonFixedFieldsGrow + ) + self.load_data_group.setTitle("Load data") + + self.filepath_label.setText("Load data from file:") + self.load_data_group_layout.addRow(self.filepath_label) + + self.filepath_line.setReadOnly(True) + self.filepath_line.setFixedWidth(150) + + self.file_button.setText("Select file") + self.file_button.clicked.connect(partial(self._click_listener, "load_file")) + self.load_data_group_layout.addRow(self.filepath_line, 
self.file_button) + + self.database_label.setText("Choose data from database:") + self.load_data_group_layout.addRow(self.database_label) + + names = self.engine.get_table_names_from_database() + for name in names: + self.database_box.addItem(name) + + self.database_button.setText("Load") + self.database_button.clicked.connect( + partial(self._click_listener, "load_database") + ) + self.load_data_group_layout.addRow(self.database_box, self.database_button) + + self.load_data_group_layout.addRow(self.import_state_label) + + def _render_columns_view(self): + self.columns_merge_group.setTitle("Merge columns") + instruction = QLabel() + instruction.setText( + "Firstly, you need to import another dataset here.\nIf everything goes well, the new " + "columns should appear inside the right empty box.\nYou should see two lists of columns " + "and you can set which columns will be merged together by drag and drop.\nWhen you're " + "ready, click the 'submit' button, then another screen with results will be shown.\n" + "You will be asked whether you'd like to accept or reject integration.\nIf you choose to " + "concatenate, the newly created dataset should be visible on the import screen." 
def _load_from_file_handle(self):
    """Let the user pick a CSV/JSON file and load it as the set to merge.

    On success the reader and the chosen path are passed to
    ``_on_success``. A ``ValueError`` raised while creating the reader is
    shown in the import state label. Cancelling the dialog clears the
    label and changes nothing else.
    """
    self.import_state_label.setText("Loading ...")
    file_path = QFileDialog.getOpenFileName(
        self, "Choose file", ".", "*.csv *.json"
    )[0]
    if not file_path:
        # The dialog returns an empty path when it is cancelled; that is
        # not an import error, so do not hand it to the loader.
        self.import_state_label.clear()
        return
    try:
        reader = self.loader.create_file_reader(file_path)
    except ValueError as e:
        self.import_state_label.setText(str(e))
    else:
        self._on_success(reader, file_path)
try: + reader = self.loader.create_database_reader(document_name) + except ValueError as e: + self.import_state_label.setText(str(e)) + else: + self._on_success(reader) + + def _on_success(self, reader, file_path=None): + if file_path is not None: + self.filepath_line.setText(file_path) + self.import_state_label.clear() + if self.engine.is_data_big(): + error = QMessageBox() + error.setIcon(QMessageBox.Warning) + error.setText("This file is too big.\nYou must save it in database!") + error.setWindowTitle("Warning") + error.exec_() + self.new_data = reader.read(None) + last_preview = self.layout.takeAt(2).widget() + if last_preview is not None: + last_preview.deleteLater() + self._render_table(self.new_data_view, False) + + for i in reversed(range(self.columns_right_layout.count() - 1)): + self.columns_right_layout.itemAt(i).widget().setParent(None) + + if self.new_data is not None and self.new_data.columns is not None: + for column in sorted(self.new_data.columns, key=lambda x: x.upper()): + self.columns_right_layout.insertWidget( + self.columns_right_layout.count() - 1, + ColumnWidget(column, self.columns_right_layout), + ) + + self.submit_button.setEnabled(self.new_data is not None) + + def _on_submit(self): + new_columns_left = [] + new_columns_right = [] + for i in range(self.columns_left_layout.count() - 1): + new_columns_left.append( + self.columns_left_layout.itemAt(i).widget().column_name + ) + for i in range(self.columns_right_layout.count() - 1): + new_columns_right.append( + self.columns_right_layout.itemAt(i).widget().column_name + ) + + if len(new_columns_left) != len(new_columns_right): + self._render_equality_warning() + if self.cancel_merge: + self.cancel_merge = False + return + + dropped_columns = [ + column + for column in self.new_data.columns + if column not in new_columns_right + ] + + overflowed_columns = new_columns_right[len(new_columns_left) :] + + for column in dropped_columns + overflowed_columns: + new_column_name = f"{column}_new" + 
def _get_imported_columns(self):
    """Return the imported data's column names sorted case-insensitively.

    Returns:
        A new list of the labels in ``self.engine.state.imported_data.columns``
        ordered by their upper-cased form.
    """
    columns = list(self.engine.state.imported_data.columns)
    columns.sort(key=lambda name: name.upper())
    return columns
def _render_equality_warning(self):
    """Warn that both column lists differ in length and record the answer.

    Shows a modal Ok/Cancel message box. Afterwards ``self.cancel_merge``
    is ``False`` when the user pressed Ok and ``True`` otherwise.
    """
    warning = QMessageBox()
    warning.setIcon(QMessageBox.Warning)
    warning.setText("The numbers of dimensions are not the same")
    warning.setInformativeText(
        "You can let the system fill the dataset or you can cancel submission and try to remove some columns. Do you want to add extra columns with nulls?"
    )
    warning.setWindowTitle("Dimensions are not the same")
    warning.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
    # Decide from the dialog's return code rather than from the clicked
    # button's caption: the caption depends on the UI language and may carry
    # a mnemonic marker, so matching the word "Cancel" is unreliable.
    # NOTE(review): dismissing the box without pressing Ok is treated as a
    # cancellation here - confirm that this is the desired behaviour.
    self.cancel_merge = warning.exec_() != QMessageBox.Ok
def clear_layout(self, layout=None):
    """Empty a layout, walking its items from last to first.

    Args:
        layout: Layout to clear; ``self.layout`` when omitted.

    Widgets are detached by re-parenting them to ``None``, nested layouts
    are cleared recursively and any other item is removed from the layout.
    """
    target = self.layout if layout is None else layout
    position = target.count()
    while position > 0:
        position -= 1
        entry = target.itemAt(position)
        if entry.widget():
            entry.widget().setParent(None)
            continue
        if entry.layout():
            self.clear_layout(entry.layout())
            continue
        # Neither a widget nor a layout (e.g. a spacer).
        target.removeItem(entry)
self.algorithm_selection_group.setTitle("Algorithm") + self.algorithm_selection_group.setMinimumSize(220, 95) + self.algorithm_selection_group_layout = QFormLayout( + self.algorithm_selection_group + ) + self.algorithm_selection_group_layout.setFormAlignment(Qt.AlignVCenter) + + self.algorithm_box = QComboBox() + self.algorithm_box.addItems( + self.engine.get_algorithms_for_techniques(self.technique_box.currentText()) + ) + self.algorithm_box.currentTextChanged.connect( + partial(self.click_listener, "algorithm") + ) + self.algorithm_selection_group_layout.addRow(self.algorithm_box) + + self.first_column.addWidget(self.technique_group) + self.first_column.addStretch(1) + self.first_column.addWidget(self.algorithm_selection_group) + self.first_column.addStretch(3) + + # options group + self.options_group = QGroupBox() + self.options_group.setTitle("Options") + self.options_group.setMinimumSize(400, 270) + self.options_group_layout = QVBoxLayout(self.options_group) + + self.options_group_layout.addWidget( + self.engine.get_option_widget( + self.technique_box.currentText(), self.algorithm_box.currentText() + ) + ) + + # animation group + self.animation_group = QGroupBox() + self.animation_group.setTitle("Animation") + self.animation_group.setMinimumSize(220, 65) + self.animation_group_layout = QHBoxLayout(self.animation_group) + + self.animation_type = QComboBox() + self.animation_type.addItems(["Step by step", "Animation", "No visualization"]) + self.animation_type.setFixedWidth(175) + + self.animation_group_layout.addStretch() + self.animation_group_layout.addWidget(QLabel("Visualization type:")) + self.animation_group_layout.addWidget(self.animation_type) + + self.second_column.addWidget(self.options_group) + self.second_column.addSpacing(10) + self.second_column.addWidget(self.animation_group) + + # description group + self.algorithm_description_group = QGroupBox() + self.algorithm_description_group.setTitle("Description") + 
def load_widget(self):
    """Unfold the setup panel, or report an error when no dataset is imported.

    When data is present the engine is asked to refresh its options before
    the parent unfolds this widget.
    """
    if self.engine.state.imported_data is not None:
        self.engine.update_options()
        self.parent().unfold(self)
        return

    # Nothing imported yet: tell the user and keep the panel folded.
    error = QMessageBox()
    error.setIcon(QMessageBox.Critical)
    error.setText("No dataset was selected")
    error.setWindowTitle("Error")
    error.exec_()
def run_handle(self):
    """Run the selected algorithm and unfold the panel that shows its output.

    The option widget for the chosen technique/algorithm supplies the keyword
    arguments passed to ``engine.run``. If the run reports success, the run
    panel is unfolded when visualization is requested, otherwise the results
    panel.
    """
    technique_name = self.technique_box.currentText()
    algorithm_name = self.algorithm_box.currentText()
    options = self.engine.get_option_widget(technique_name, algorithm_name).get_data()

    mode = self.animation_type.currentText()
    visualize = mode != "No visualization"
    animate = mode == "Animation"

    succeeded = self.engine.run(
        technique_name, algorithm_name, visualize, animate, **options
    )
    if not succeeded:
        return

    destination = "algorithm_run_widget" if visualize else "results_widget"
    self.parent().unfold_by_id(destination)
QGroupBox(self.frame) + self.load_data_group_layout = QFormLayout(self.load_data_group) + self.load_data_group_layout.setFieldGrowthPolicy( + QFormLayout.AllNonFixedFieldsGrow + ) + self.load_data_group.setTitle("Load data") + + self.filepath_label = QLabel(self.load_data_group) + self.filepath_label.setText("Load data from file:") + + self.filepath_line = QLineEdit(self.load_data_group) + self.filepath_line.setReadOnly(True) + self.filepath_line.setFixedWidth(150) + + self.load_data_group_layout.addRow(self.filepath_label) + + self.file_button = QPushButton(self.load_data_group) + self.file_button.setText("Select file") + self.file_button.clicked.connect(partial(self.click_listener, "load_file")) + + self.load_data_group_layout.addRow(self.filepath_line, self.file_button) + + self.database_label = QLabel(self.load_data_group) + self.database_label.setText("Choose data from database:") + self.load_data_group_layout.addRow(self.database_label) + + self.database_box = QComboBox(self.load_data_group) + self.set_available_tables() + self.database_button = QPushButton(self.load_data_group) + self.database_button.setText("Load") + self.database_button.clicked.connect( + partial(self.click_listener, "load_database") + ) + + self.load_data_group_layout.addRow(self.database_box, self.database_button) + + self.generate_data_label = QLabel(self.load_data_group) + self.generate_data_label.setText("Generate data for a specific algorithm:") + self.load_data_group_layout.addRow(self.generate_data_label) + + self.generate_button = QPushButton(self.load_data_group) + self.generate_button.setText("Generate") + self.generate_button.clicked.connect(partial(self.click_listener, "generate")) + self.load_data_group_layout.addRow(self.generate_button) + + self.generate_window = DataGeneratorWidget( + self.engine, callback=self.update_data_view + ) + + self.import_state_label = QLabel(self.load_data_group) + self.load_data_group_layout.addRow(self.import_state_label) + + # options group + 
self.options_group = QGroupBox(self.frame) + self.options_layout = QVBoxLayout(self.options_group) + self.options_group.setTitle("Options") + + self.reject_button = QPushButton(self.options_group) + self.reject_button.setText("Reject this data") + self.reject_button.clicked.connect(partial(self.click_listener, "reject_data")) + self.options_layout.addWidget(self.reject_button, 1) + + self.save_button = QPushButton(self.options_group) + self.save_button.setText("Save to database") + self.save_button.clicked.connect(partial(self.click_listener, "save_data")) + self.save_button.setEnabled(False) + self.options_layout.addWidget(self.save_button, 1) + + self.merge_button = QPushButton(self.options_group) + self.merge_button.setText("Merge another dataset") + self.merge_button.clicked.connect(partial(self.click_listener, "merge_data")) + self.merge_button.setEnabled(False) + self.merge_button.setMinimumHeight(23) + self.options_layout.addWidget(self.merge_button, 1) + + # columns group + self.columns_group = QGroupBox(self.frame) + self.columns_group.setTitle("Limit data") + self.columns_group_layout = QFormLayout(self.columns_group) + + self.limit_type_box = QComboBox() + self.limit_type_box.addItems(["random", "first"]) + self.limit_type_box.setEnabled(False) + self.limit_number_box = QSpinBox() + self.limit_number_box.setMinimum(1) + self.limit_number_box.setEnabled(False) + self.limit_button = QPushButton("Limit number of rows") + self.limit_button.clicked.connect(partial(self.click_listener, "limit_data")) + self.limit_button.setEnabled(False) + + self.scroll_box = QGroupBox(self.frame) + self.columns_group_form_layout = QFormLayout(self.scroll_box) + + self.scroll = QScrollArea() + self.scroll.setWidget(self.scroll_box) + self.scroll.setWidgetResizable(True) + + self.columns_button = QPushButton("Select columns") + self.columns_button.setEnabled(False) + self.columns_button.clicked.connect(partial(self.click_listener, "columns")) + + 
def set_options(self):
    """Enable the data controls after a load and warn when the data is big."""
    controls = (
        self.save_button,
        self.merge_button,
        self.columns_button,
        self.limit_button,
        self.limit_type_box,
        self.limit_number_box,
    )
    for control in controls:
        control.setEnabled(True)

    if not self.engine.is_data_big():
        return

    # The engine flags the dataset as big: show a warning dialog.
    error = QMessageBox()
    error.setIcon(QMessageBox.Warning)
    error.setText("This file is too big.\nYou must save it in database!")
    error.setWindowTitle("Warning")
    error.exec_()
def get_checked_columns(self) -> List[str]:
    """Return the labels of the ticked checkboxes in the columns form, in row order."""
    form = self.columns_group_form_layout
    checkboxes = (form.itemAt(row).widget() for row in range(form.count()))
    return [checkbox.text() for checkbox in checkboxes if checkbox.isChecked()]
LoadingWidget(self.save_data_handle) + loading.execute() + case "columns": + columns = self.get_checked_columns() + if not columns: + error = QMessageBox() + error.setIcon(QMessageBox.Critical) + error.setText("No columns were chosen") + error.setWindowTitle("Error") + error.exec_() + return + self.engine.limit_data(columns=self.get_checked_columns()) + self.set_columns_grid() + self.display_data() + case "limit_data": + self.engine.limit_data( + limit_type=self.limit_type_box.currentText(), + limit_num=self.limit_number_box.value(), + ) + self.display_data() + case "generate": + self.generate_window.show() + case "merge_data": + self.new_window = MergingSetsScreen(self, self.update_data_view) + self.new_window.show() + + def load_from_file_handle(self): + self.import_state_label.setText("Loading ...") + file_path: str = QFileDialog.getOpenFileName( + self, "Choose file", ".", "*.csv *.json" + )[0] + try: + self.engine.load_data_from_file(file_path) + except ValueError as e: + self.import_state_label.setText(str(e)) + else: + self._on_success(file_path) + + def load_from_database_handle(self): + self.import_state_label.setText("Loading ...") + document_name = self.database_box.currentText() + try: + self.engine.load_data_from_database(document_name) + except ValueError as e: + self.import_state_label.setText(str(e)) + else: + self._on_success() + + def _on_success(self, file_path=None): + if file_path is not None: + self.filepath_line.setText(basename(file_path)) + self.clear_widgets() + self.set_options() + self.engine.read_data() + self.set_columns_grid() + self.display_data() + + def reject_data_handle(self): + self.clear_widgets() + self.engine.clear_import() + self.reset_data_table() + + def save_data_handle(self): + self.engine.drop_additional_columns() + text, is_ok = QInputDialog.getText( + self, "input name", "Enter name of collection:" + ) + if is_ok: + if text: + label = self.engine.save_to_database(str(text)) + if label: + 
self.import_state_label.setText(label) + else: + self.import_state_label.setText("Data was stored in database.") + else: + self.import_state_label.setText("The name of collection is not valid.") diff --git a/src/widgets/unfold_widgets/preprocessing_widget.py b/src/widgets/unfold_widgets/preprocessing_widget.py new file mode 100644 index 0000000..c23fe6e --- /dev/null +++ b/src/widgets/unfold_widgets/preprocessing_widget.py @@ -0,0 +1,403 @@ +import matplotlib.pyplot as plt +from PyQt5.QtCore import QRect, Qt +from PyQt5.QtWidgets import ( + QApplication, + QCheckBox, + QComboBox, + QDesktopWidget, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QMessageBox, + QPushButton, + QScrollArea, + QSizePolicy, + QSpinBox, + QSplashScreen, + QVBoxLayout, + QWidget, +) + +from visualization import plots +from widgets import UnfoldWidget +from widgets.components import QLabelWithTooltip, SamplesColumnsChoice +from widgets.tables import DataPreviewScreen, PreviewReason + + +class PreprocessingWidget(UnfoldWidget): + def __init__(self, parent: QWidget, engine): + super().__init__(parent, engine, "preprocessing_widget", "PREPROCESSING") + + self.data_submitted = False + self.mark_reduced_columns = False + self.button.disconnect() + self.button.clicked.connect(lambda: self.get_data()) + + self.plot_types = ["Histogram", "Pie", "Null frequency", "Scatter plot"] + + # plot picker group + self.plot_picker_group = QGroupBox(self.frame) + self.plot_picker_group_layout = QFormLayout(self.plot_picker_group) + + self.plot_picker_group.setTitle("Choose data to plot") + + self.parameters_widget = SamplesColumnsChoice() + self.parameters_widget_connection = None + self.group_picker_label = QLabel("Group by:") + self.group_picker_label.setMinimumHeight(23) + self.group_select_box = QComboBox() + self.group_select_box.setMinimumHeight(23) + + self.plot_picker_group.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) + + self.plot_picker_label = QLabel(self.plot_picker_group) + 
self.plot_picker_label.setText("Select plot type:") + self.plot_picker_label.setMinimumHeight(23) + self.plot_select_box = QComboBox(self.plot_picker_group) + self.plot_select_box.setMinimumHeight(23) + self.plot_select_box.addItems(self.plot_types) + self.plot_picker_group_layout.addRow( + self.plot_picker_label, self.plot_select_box + ) + + self.column_picker_label = QLabel(self.plot_picker_group) + self.column_picker_label.setText("Select column:") + self.column_picker_label.setMinimumHeight(23) + self.column_select_box = QComboBox(self.plot_picker_group) + self.column_select_box.setMinimumHeight(23) + + self.plot_picker_group_layout.addRow( + self.column_picker_label, self.column_select_box + ) + + self.plot_select_box.currentTextChanged.connect(self.change_settings) + + self.plot_picker_submit = QPushButton(self.plot_picker_group) + self.plot_picker_submit.setText("Plot") + self.plot_picker_submit.setMinimumHeight(23) + self.plot_picker_submit.clicked.connect( + lambda: self.plot_data( + self.column_select_box.currentText(), self.plot_select_box.currentText() + ) + ) + + self.plot_picker_group_layout.addRow(self.plot_picker_submit) + + # estimation group + self.estimate_group = QGroupBox(self.frame) + self.estimate_group_layout = QFormLayout(self.estimate_group) + self.estimate_manually_button = QPushButton(self.estimate_group) + self.estimate_automatically_button = QPushButton(self.estimate_group) + self.render_estimation_group() + + # initialize reduction results screen + self.preview_screen = None + + # automatic reduction group + self.auto_reduction_group = QGroupBox(self.frame) + self.auto_reduction_group.setTitle("Reduce dimensions") + self.auto_reduction_group_layout = QFormLayout(self.auto_reduction_group) + + self.num_dimensions_spinbox = QSpinBox() + self.manual_reduction = QPushButton(self.auto_reduction_group) + self.auto_reduction = QPushButton(self.auto_reduction_group) + + self.num_dimensions_spinbox.setMinimum(1) + 
self.num_dimensions_spinbox.setValue(1) + self.auto_reduction_group_layout.addRow( + QLabel("Number of dimensions:"), self.num_dimensions_spinbox + ) + + self.manual_reduction_label = QLabelWithTooltip( + "Reduce with fixed number", + "Reduce dimensions using the Principal Component Analysis algorithm.", + ) + self.manual_reduction_label.layout.setAlignment( + Qt.AlignCenter | Qt.AlignVCenter + ) + self.manual_reduction.setLayout(self.manual_reduction_label.layout) + self.manual_reduction.setMinimumHeight(23) + self.auto_reduction_group_layout.addRow(self.manual_reduction) + self.manual_reduction.clicked.connect( + lambda: self.reduce_dimensions(self.num_dimensions_spinbox.value()) + ) + + self.auto_reduction_label = QLabelWithTooltip( + "Reduce dynamically", + "Reduce dimensions using the Principal Component Analysis algorithm.\nTake dimensions with imapct more than 5%.", + ) + self.auto_reduction_label.layout.setAlignment(Qt.AlignCenter | Qt.AlignVCenter) + self.auto_reduction.setLayout(self.auto_reduction_label.layout) + self.auto_reduction.setMinimumHeight(23) + self.auto_reduction_group_layout.addRow(self.auto_reduction) + self.auto_reduction.clicked.connect(lambda: self.reduce_dimensions()) + + # plot stats window + self.plot_widget = QGroupBox(self.frame) + self.plot_widget.setTitle("Plot") + + self.plot_layout = QVBoxLayout() + self.plot_widget.setLayout(self.plot_layout) + + # column rejection group + self.columns_group = QGroupBox(self.frame) + self.columns_group.setTitle("Columns") + self.columns_group_layout = QHBoxLayout(self.columns_group) + + self.scroll_box = QGroupBox(self.frame) + self.columns_group_form_layout = QFormLayout(self.scroll_box) + + self.scroll = QScrollArea() + self.scroll.setWidget(self.scroll_box) + self.scroll.setWidgetResizable(True) + self.scroll.setMinimumHeight(26) + + self.scroll.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.columns_group_layout.addWidget(self.scroll) + 
def change_settings(self):
    """Swap the plot-picker form rows to match the selected plot type.

    "Scatter plot" replaces the single-column row with the samples/columns
    chooser and a "group by" row; any other type restores the column row.
    """
    form = self.plot_picker_group_layout

    if self.plot_select_box.currentText() != "Scatter plot":
        scatter_only = (
            self.parameters_widget,
            self.group_picker_label,
            self.group_select_box,
        )
        for part in scatter_only:
            part.setParent(None)
        form.insertRow(1, self.column_picker_label, self.column_select_box)
        return

    for part in (self.column_select_box, self.column_picker_label):
        part.setParent(None)
    form.insertRow(1, self.parameters_widget)
    form.insertRow(2, self.group_picker_label, self.group_select_box)
def deactivate_scatter_plot(self):
    """Drop the samples-changed connection (if any) and disable the sample controls."""
    chooser = self.parameters_widget
    connection = self.parameters_widget_connection

    if connection is not None:
        chooser.samples_changed.disconnect(connection)
        # Forget the handle so a later call does not disconnect twice.
        self.parameters_widget_connection = None

    for control in (chooser.sample_button, chooser.sample_box):
        control.setEnabled(False)

Loading...

", Qt.AlignCenter) + loading_screen.setGeometry( + QRect(size.width() // 2 - 125, size.height() // 2 - 50, 250, 100) + ) # hardcoded alignment + loading_screen.show() + QApplication.processEvents() + + self.parent().unfold(self) + self.column_select_box.clear() + self.column_select_box.addItems(self.engine.get_columns()) + self.group_select_box.clear() + self.group_select_box.addItems([""] + list(self.engine.get_columns())) + self._clear_plot() + self.parameters_widget.new_columns_name(self.engine.get_numeric_columns()) + self.parameters_widget.new_size(self.engine.get_size()) + self.add_columns_to_layout() + self.engine.clean_data("cast") + + max_dimensions = self.engine.number_of_numeric_columns() - len( + [ + column + for column in self.engine.state.reduced_columns + if column in self.engine.state.imported_data + ] + ) + self.set_reduction_bounds(max_dimensions) + loading_screen.close() + + def plot_data(self, column_name, plot_type): + self._clear_plot() + scatter_settings = self.parameters_widget.get_parameters() + group_by = self.group_select_box.currentText() + scatter_settings["group_by"] = group_by or None + plot_box = self.create_plot(column_name, plot_type, scatter_settings) + if plot_type == "Scatter plot": + self.activate_scatter_plot() + self.plot_layout.addWidget(plot_box) + + def _clear_plot(self): + self.deactivate_scatter_plot() + for i in reversed(range(self.plot_layout.count())): + figure = self.plot_layout.itemAt(i).widget().figure + plt.close(figure) + self.plot_layout.itemAt(i).widget().setParent(None) + + def add_columns_to_layout(self): + self.clear_column_layout() + columns = self.engine.get_raw_columns() + selected_columns = self.engine.get_columns() + self.mark_reduced_columns = False + for column in columns: + checkbox = QCheckBox(column) + checkbox.setChecked(column in selected_columns) + self.columns_group_form_layout.addRow(checkbox) + + def clear_column_layout(self): + for i in 
def submit_columns(self):
    """Apply the ticked columns to the engine state and reload the panel.

    Shows an error when nothing is ticked. When the selection contains rows
    with nulls, the user is asked for confirmation first; the selection is
    applied only if ``data_submitted`` ends up true.
    """
    self.data_submitted = False

    form = self.columns_group_form_layout
    checkboxes = [form.itemAt(row).widget() for row in range(form.count())]
    chosen = [checkbox.text() for checkbox in checkboxes if checkbox.isChecked()]

    if not chosen:
        error = QMessageBox()
        error.setIcon(QMessageBox.Critical)
        error.setText("No columns were chosen")
        error.setWindowTitle("Error")
        error.exec_()
        return

    if self.engine.has_rows_with_nulls(chosen):
        # The warning dialog's click handler decides data_submitted.
        self.remove_nulls_warning()
    else:
        self.data_submitted = True

    if not self.data_submitted:
        return

    self.engine.set_state(chosen)
    self.engine.clean_data("remove")
    self.get_data()
def reduce_dimensions(self, dim_number=None):
    """Run dimensionality reduction and show the result preview.

    Columns produced by a previous reduction are dropped from both the
    imported and the raw data before the engine reduces again.

    Args:
        dim_number: Target number of dimensions, forwarded unchanged to
            ``engine.reduce_dimensions``; ``None`` is passed through as is.
    """
    state = self.engine.state
    # Both frames tolerate missing columns now. Previously only imported_data
    # did, so a reduced column absent from raw_data raised KeyError.
    for frame in (state.imported_data, state.raw_data):
        frame.drop(state.reduced_columns, axis=1, inplace=True, errors="ignore")

    state.reduced_columns = self.engine.reduce_dimensions(dim_number)
    self.mark_reduced_columns = True
    self.show_reduction_results()
class ResultsWidget(UnfoldWidget):
    """Unfoldable panel with one tab per algorithm and one sub-tab per run."""

    def __init__(self, parent, engine):
        super().__init__(parent, engine, "results_widget", "RESULTS")
        # Rebuild the tabs every time the panel is opened.
        self.button.disconnect()
        self.button.clicked.connect(self.load_widget)
        self.engine = engine

        # algorithm results tab widget
        self.results_tab_widget = QTabWidget(self)

        # layout setup
        self.layout = QHBoxLayout(self.frame)
        self.layout.addWidget(self.results_tab_widget)

    def load_widget(self):
        """Unfold the panel and rebuild the result tabs from the engine state.

        The tab of the most recently run algorithm, and its last result, are
        selected when ``state.last_algorithm`` is set.
        """
        self.parent().unfold(self)

        tabs = self.results_tab_widget
        # Remove from the back so remaining indices stay valid.
        for position in range(tabs.count() - 1, -1, -1):
            tabs.removeTab(position)

        last_run = self.engine.state.last_algorithm
        if last_run is None:
            last_technique = last_algorithm = ""
        else:
            last_technique, last_algorithm = last_run

        all_results = self.engine.state.algorithm_results_widgets
        for technique, algorithms in all_results.items():
            for algorithm, results in algorithms.items():
                run_tabs = QTabWidget()
                for number, result_widget in enumerate(results, start=1):
                    run_tabs.addTab(result_widget, f"{number}")
                position = tabs.addTab(run_tabs, f"{technique}: {algorithm}")

                is_last_run = technique == last_technique and algorithm == last_algorithm
                if is_last_run:
                    tabs.setCurrentIndex(position)
                    tabs.widget(position).setCurrentIndex(len(results) - 1)
class UnfoldWidget(QWidget):
    """Base panel made of a rotated toggle button and a content frame.

    The frame starts with a fixed width of 0; clicking the button asks the
    parent to unfold this widget.
    """

    def __init__(self, parent: QWidget, engine, object_id: str, button_text: str):
        super().__init__(parent)

        self.engine = engine
        self.setObjectName(object_id)
        self.setFixedWidth(UNFOLD_BUTTON_WIDTH)

        def request_unfold():
            # Resolved at click time, so a later re-parenting is respected.
            self.parent().unfold(self)

        # unfold button
        self.button = toggle = RotatedButton(self)
        toggle.setFixedWidth(UNFOLD_BUTTON_WIDTH)
        toggle.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Expanding)
        toggle.clicked.connect(request_unfold)
        toggle.setText(button_text)

        # main frame
        self.frame = content = QFrame(self)
        content.setFixedWidth(0)

        # layout: button and frame side by side, no margins or spacing
        row = QHBoxLayout()
        for part in (toggle, content):
            row.addWidget(part)
        row.setContentsMargins(0, 0, 0, 0)
        row.setSpacing(0)

        self.setLayout(row)
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
class TestReader(TestCase):
    """Tests for ``Reader`` run against the ``test``/``test`` database target.

    Every test starts from a collection that was emptied and then filled with
    four identical documents.
    """

    def setUp(self) -> None:
        """Create the reader, then reset the collection to four known rows."""
        self.reader = Reader("test", "test")
        # rows the reader is expected to return for the two queried columns
        self.query_test_result = [
            {"key": "value", "another_key": 1} for _ in range(4)
        ]

        DocumentRemover("test", "test").remove_all()
        writer = Writer("test", "test")
        for _ in range(4):
            # a new dict per insert, kept separate from the expected rows
            # NOTE(review): presumably the writer may mutate it - confirm
            writer.add_document({"key": "value", "another_key": 1})

    def test_execute_query(self):
        """Querying both columns returns the four inserted rows."""
        rows = self.reader.execute_query(columns=["key", "another_key"])
        self.assertEqual(rows, self.query_test_result)

    def test_get_nth_chunk(self):
        """The default chunk equals a DataFrame built from the expected rows."""
        expected = pd.DataFrame(self.query_test_result).to_dict()
        chunk = self.reader.get_nth_chunk().to_dict()
        self.assertDictEqual(expected, chunk)

    def test_get_rows_number(self):
        """The reader reports four rows."""
        rows_number = self.reader.get_rows_number()
        self.assertEqual(rows_number, 4)

    def test_get_columns_names(self):
        """Column names are reported as ``_id``, ``key``, ``another_key``."""
        columns = self.reader.get_columns_names()
        self.assertEqual(columns, ["_id", "key", "another_key"])