# import logging
from dataclasses import dataclass
from functools import cached_property
from typing import List, Tuple, NamedTuple, Dict, Optional

import numpy as np

from PySide2.QtGui import QTransform
from .utils import Pnt, DictPnt, MatchingPoints

# Residuals returned by QTransformCalculator.calculate_errors: one Euclidean
# distance per matched point pair, in the order of the matched points.
Errors = Tuple[float, ...]


class AffineTransformParams(NamedTuple):
    """Six coefficients of a 2-D affine map from source to target coordinates.

    A source point ``(x, y)`` is mapped to::

        x' = a * x + b * y + c
        y' = d * x + e * y + f
    """

    a: float  # weight of source x in target x
    b: float  # weight of source y in target x
    c: float  # translation along target x
    d: float  # weight of source x in target y
    e: float  # weight of source y in target y
    f: float  # translation along target y


class QTransformCalculator:
    """Fit a 2-D affine transform between matched points and expose it as a ``QTransform``.

    Points from ``points_table`` (source) and ``points_map`` (target) are paired
    by key, and the six affine parameters are estimated by linear least squares
    so that::

        x_target ~= a * x_source + b * y_source + c
        y_target ~= d * x_source + e * y_source + f
    """

    # Each matched pair contributes two equations, so three pairs are the
    # minimum needed to determine the six affine unknowns.
    MIN_POINTS_FOR_TRANSFORM = 3
    # With exactly three pairs the fit is exact and every residual is zero,
    # so residuals are only reported once there is at least one redundant pair.
    MIN_POINTS_FOR_ERRORS = 4

    @dataclass
    class Result:
        """Outcome of ``QTransformCalculator.transform``."""

        # Least-squares affine parameters (a, b, c, d, e, f).
        affine_transform_params: AffineTransformParams
        # Residual (Euclidean distance) of each matched point, keyed by the point's key.
        errors_dict: Dict[int, float]

    def __init__(self, points_table: DictPnt, points_map: DictPnt) -> None:
        """Pair up the points that share a key in both mappings.

        :param points_table: source points, keyed by point id.
        :param points_map: target points, keyed by point id.

        Keys absent from ``points_map``, and entries whose value is ``None`` on
        either side, are skipped. Pairs keep the iteration order of ``points_table``.
        """
        source_points: List[Pnt] = []
        target_points: List[Pnt] = []
        errors_indexes: List[int] = []

        for key, src_pnt in points_table.items():
            dst_pnt = points_map.get(key)
            # Compare against None explicitly: truthiness is ambiguous for
            # array-like points and would wrongly reject points that evaluate false.
            if src_pnt is None or dst_pnt is None:
                continue
            matching_points = MatchingPoints(src_pnt, dst_pnt)
            source_points.append(matching_points.src_pnt)
            target_points.append(matching_points.dst_pnt)
            errors_indexes.append(key)

        self._source_points = source_points
        self._target_points = target_points
        self._errors_indexes = errors_indexes

    def calculate(self) -> Optional[QTransform]:
        """Return the fitted affine map as a ``QTransform``.

        :returns: ``None`` when fewer than ``MIN_POINTS_FOR_TRANSFORM`` pairs were matched.
        """
        if len(self.source_points) < self.MIN_POINTS_FOR_TRANSFORM:
            return None

        a, b, c, d, e, f = self.affine_transform_params
        # Qt's 6-argument constructor is QTransform(m11, m12, m21, m22, dx, dy) and maps
        #     x' = m11 * x + m21 * y + dx
        #     y' = m12 * x + m22 * y + dy
        # so the fitted cross terms must be transposed: b (weight of y in x') is m21
        # and d (weight of x in y') is m12.
        return QTransform(a, d, b, e, c, f)

    @property
    def source_points(self) -> List[Pnt]:
        """Matched source points, in pairing order."""
        return self._source_points

    @property
    def target_points(self) -> List[Pnt]:
        """Matched target points, aligned index-by-index with ``source_points``."""
        return self._target_points

    @property
    def errors_indexes(self) -> List[int]:
        """Keys of the matched pairs, aligned index-by-index with ``source_points``."""
        return self._errors_indexes

    def transform(self) -> Optional[Result]:
        """Return the fitted parameters together with per-point residuals.

        :returns: ``None`` when fewer than ``MIN_POINTS_FOR_ERRORS`` pairs were matched.
        """
        if len(self.source_points) < self.MIN_POINTS_FOR_ERRORS:
            return None
        errors_dict = dict(zip(self.errors_indexes, self.calculate_errors()))
        return self.Result(self.affine_transform_params, errors_dict)

    @cached_property
    def affine_transform_params(self) -> AffineTransformParams:
        """Least-squares estimate of the affine parameters (computed once, then cached).

        Each matched pair contributes two rows to the linear system ``A @ p = rhs``::

            [x1, y1, 1,  0,  0, 0] . p = x2
            [ 0,  0, 0, x1, y1, 1] . p = y2

        with ``p = (a, b, c, d, e, f)``. NOTE: if the source points are collinear
        or coincident the system is rank-deficient and ``lstsq`` returns its
        minimum-norm solution rather than raising.
        """
        src = np.array([(x, y) for x, y in self.source_points], dtype=float).reshape(-1, 2)
        dst = np.array([(x, y) for x, y in self.target_points], dtype=float).reshape(-1, 2)

        # noinspection PyPep8Naming
        A = np.zeros((2 * len(src), 6))
        A[0::2, 0:2] = src  # even rows: a * x1 + b * y1 + c = x2
        A[0::2, 2] = 1.0
        A[1::2, 3:5] = src  # odd rows: d * x1 + e * y1 + f = y2
        A[1::2, 5] = 1.0
        # Flattening (n, 2) targets yields [x2_0, y2_0, x2_1, y2_1, ...],
        # matching the interleaved row order of A.
        rhs = dst.reshape(-1)

        params = np.linalg.lstsq(A, rhs, rcond=None)[0]
        return AffineTransformParams(*params)

    def calculate_errors(self) -> Errors:
        """Return the residual of every matched pair, in pairing order."""
        return tuple(
            self.calculate_single_error(source_point, target_point)
            for source_point, target_point in zip(self.source_points, self.target_points)
        )

    def calculate_single_error(self, source_pnt: Pnt, target_point: Pnt) -> float:
        """Return the Euclidean distance between the mapped ``source_pnt`` and ``target_point``."""
        (x1, y1), (x2, y2) = source_pnt, target_point
        a, b, c, d, e, f = self.affine_transform_params

        # Where the fitted transform sends the source point.
        x2_predicted = a * x1 + b * y1 + c
        y2_predicted = d * x1 + e * y1 + f

        return float(np.hypot(x2_predicted - x2, y2_predicted - y2))
