# Copyright (C) 2022 The Qt Company Ltd.
# SPDX-License-Identifier: LicenseRef-Qt-Commercial

import warnings
from enum import Enum
from functools import partial
from typing import Callable, List, Optional, Tuple

from PySide6.QtCore import QObject, Qt, Slot
from PySide6.QtGui import QStandardItem, QStandardItemModel
from PySide6.QtXml import QDomDocument, QDomElement, QDomNode


class ItemType(Enum):
    """Kind of code model entry a typesystem element represents.

       ELEMENT_TYPE maps the XML tag names to these members and
       ITEM_DESCRIPTION provides the matching display text."""
    NAMESPACE = 1  # 'namespace-type' element
    OBJECT = 2  # 'object-type' element
    VALUE = 3  # 'value-type' element
    ENUM = 4  # 'enum-type' element
    FUNCTION = 5  # 'function' element (identified by its signature)


class NodeType(Enum):
    """Classification of a DOM node as returned by _node_type(), used
       when cleaning a raw typesystem tree."""
    XML_COMMENT = 0  # Comment node
    OTHER = 1  # Any other node that is not an element
    OTHER_ELEMENT = 2  # Element whose tag is not listed in ELEMENT_TYPE
    CODE_ELEMENT = 3  # Code model element that is kept
    DISCARD_CODE_ELEMENT = 4  # Code model element that is filtered out


class RecursionMode(Enum):
    """Policy deciding which model items are visited when recursing
       over the model (evaluated by CodeModel._do_recurse())."""
    ALL = 0  # Visit every item regardless of its check state
    CHECKED = 1  # Visit only items whose check state is Qt.Checked
    UNCHECKED = 2  # Visit only items whose check state is not Qt.Checked


# Maps the tag name of a typesystem XML element to the ItemType it
# represents. Tags that are not listed here are not code model elements.
ELEMENT_TYPE = {
    'namespace-type': ItemType.NAMESPACE,
    'object-type': ItemType.OBJECT,
    'value-type': ItemType.VALUE,
    'enum-type': ItemType.ENUM,
    'function': ItemType.FUNCTION,
}


# Display text used for the description column of the model, by ItemType.
ITEM_DESCRIPTION = {
    ItemType.NAMESPACE: "Namespace",
    ItemType.OBJECT: "Object",
    ItemType.VALUE: "Value",
    ItemType.ENUM: "Enumeration",
    ItemType.FUNCTION: "Function",
}


# Names of object/value types that are discarded when cleaning a raw
# typesystem tree (see _node_type()).
EXCLUDED_QT_CLASSES = [
    'QAtomicOpsSupport',
    'QArrayData',
    'QChar',
    'QCharRef',
    'QAtomicWindowsType',
    'QByteArrayRef',
    'QByteArrayDataPtr',
    'QByteRef',
    'QFlag',
    'QGenericArgument',
    'QGenericReturnArgument',
    'QIncompatibleFlag',
    'QIntegerForSize',
    'QInternal',
    'QLatin1String',
    'QLatin1Char',
    'QListData',
    'QListSpecialMethods',
    'QMetaObject',
    'QMetaType',
    'QMetaTypeId',
    'QMetaTypeId2',
    'QObjectData',
    'QObjectUserData',
    'QScopedPointerPodDeleter',
    'QStringDataPtr',
    'QStringRef',
    'QTypeInfo',
]


def _clone_document(source: QDomDocument) -> QDomDocument:
    """DOM helpers: Return a copy of a DOM document.

       A new document is created with the doctype of source and a deep
       import of the source's document element is appended to it."""
    clone = QDomDocument(source.doctype())
    # The second argument requests a deep import (including all children).
    root_copy = clone.importNode(source.documentElement(), True)
    clone.appendChild(root_copy)
    return clone


def _child_elements(dom_element: QDomElement) -> List[QDomElement]:
    """DOM helpers: Return the direct children of a DOM node that are
       elements, converted to QDomElement, in document order."""
    nodes = dom_element.childNodes()
    candidates = (nodes.at(index) for index in range(nodes.count()))
    return [node.toElement() for node in candidates if node.isElement()]


def _item_type(dom_element: QDomElement) -> Optional[Tuple[ItemType, str]]:
    """Return a tuple of ItemType/name for code model DOM nodes.

       The tag name is looked up in ELEMENT_TYPE; None is returned for
       elements that are not code model nodes. Functions are identified by
       their 'signature' attribute, all other items by their 'name'."""
    item_type = ELEMENT_TYPE.get(dom_element.tagName())
    if item_type is None:
        return None
    if item_type == ItemType.FUNCTION:
        identifier = dom_element.attribute('signature')
    else:
        identifier = dom_element.attribute('name')
    return item_type, identifier


def _discard_function(signature: str) -> bool:
    """ Helper for CodeModel.parse(): Returns whether function is internal."""
    return (signature.startswith('qt_getEnumMetaObject(')
            or signature.startswith('qt_getEnumName('))


def _node_type(dom_node: QDomNode) -> NodeType:
    """Helper for CodeModel.parse(): Classify a DOM node.

       Comments and non-element nodes are reported as such. Elements that
       are not code model nodes yield OTHER_ELEMENT. Code model nodes are
       either kept (CODE_ELEMENT) or flagged for removal
       (DISCARD_CODE_ELEMENT)."""
    if dom_node.isComment():
        return NodeType.XML_COMMENT
    if not dom_node.isElement():
        return NodeType.OTHER
    item = _item_type(dom_node.toElement())
    if not item:
        return NodeType.OTHER_ELEMENT
    kind, name = item
    # Uninteresting nodes: anything with 'Private' in its name, namespace
    # std, excluded Qt classes and internal functions.
    discard = (
        'Private' in name
        or (kind == ItemType.NAMESPACE and name == 'std')
        or (kind in (ItemType.OBJECT, ItemType.VALUE)
            and name in EXCLUDED_QT_CLASSES)
        or (kind == ItemType.FUNCTION and _discard_function(name))
    )
    if discard:
        return NodeType.DISCARD_CODE_ELEMENT
    return NodeType.CODE_ELEMENT


def _remove_function_nodes(dom_element: QDomElement) -> None:
    """Helper for CodeModel.parse(): Remove the 'function' elements that
       are direct children of dom_element (global functions when called
       on the document element)."""
    functions = [element for element in _child_elements(dom_element)
                 if element.tagName() == 'function']
    for function in functions:
        dom_element.removeChild(function)


def _clean_dom_tree_recursion(dom_element: QDomElement) -> None:
    """Helper for CodeModel.parse(): Recursively clean the DOM below
       dom_element.

       XML comments and code elements flagged for discarding are removed.
       The remaining code elements are moved to the end of the child list,
       ordered alphabetically by name (signature for functions)."""
    removable = (NodeType.XML_COMMENT, NodeType.DISCARD_CODE_ELEMENT)
    kept_code_elements = []
    nodes = dom_element.childNodes()
    # Visit the children back to front; presumably this keeps the indexes
    # of the not yet visited nodes valid while nodes are being removed.
    for index in range(nodes.count() - 1, -1, -1):
        node = nodes.at(index)
        kind = _node_type(node)
        if kind in removable:
            dom_element.removeChild(node)
            continue
        if kind == NodeType.OTHER:
            continue
        element = node.toElement()
        _clean_dom_tree_recursion(element)
        if kind == NodeType.CODE_ELEMENT:
            kept_code_elements.append(element)
    # Re-append the code elements in sorted order
    kept_code_elements.sort(key=lambda element: _item_type(element)[1])
    for element in kept_code_elements:
        dom_element.removeChild(element)
        dom_element.appendChild(element)


def _remove_unselected_recursion(stack: List[str],
                                 dom_element: QDomElement,
                                 unselected: List[str]) -> None:
    """Helper for CodeModel.typesystem(): Recurse over DOM nodes
       and remove all unselected branches.

       stack holds the names of the enclosing code model elements; joined
       by '.' together with the element's own name it forms the qualified
       name that is looked up in unselected."""
    for element in _child_elements(dom_element):
        item = _item_type(element)
        if item is None:
            # Not a code model element: it has no name, just descend.
            _remove_unselected_recursion(stack, element, unselected)
            continue
        stack.append(item[1])
        if '.'.join(stack) in unselected:
            dom_element.removeChild(element)
        else:
            _remove_unselected_recursion(stack, element, unselected)
        stack.pop()


class CodeModel(QStandardItemModel):
    """Represents TypeSystem XML as a hierarchical model
       First column: Description (Class/Enum,...), user data: ItemType
       Second column: Name, user data: Qualified name

       Every row is checkable via its description item; unchecked rows
       are left out of the document returned by typesystem().
    """

    def __init__(self, parent: Optional[QObject] = None):
        super().__init__(0, 2, parent)
        self.setHorizontalHeaderLabels(["Type", "Name"])
        self.setHeaderData(0, Qt.Horizontal, Qt.AlignCenter, Qt.TextAlignmentRole)
        # DOM document the rows were created from; None until a document
        # has been parsed or loaded.
        self._dom_document = None

    def clear(self) -> None:
        """Remove all rows (done via removeRows() rather than the base
           class clear())."""
        self.removeRows(0, self.rowCount())

    def is_empty(self) -> bool:
        """Return whether the model has no top level rows."""
        return self.rowCount() == 0

    def parse(self, xml: str, is_raw: bool = True) -> bool:
        """Parse a typesystem tree and populate the model. is_raw means it was
           obtained from the dump tool and needs to be cleaned (remove
           inapplicable nodes, comments, etc).

           Returns False (after emitting a warning) when the XML cannot be
           parsed; the model is left untouched in that case."""
        dom_document = QDomDocument()
        parse_result = dom_document.setContent(xml, True)
        # NOTE(review): PySide6 documents this overload as returning an
        # (ok, errorMsg, errorLine, errorColumn) tuple - confirm for the
        # targeted version. A non-empty tuple is always truthy, so the
        # success flag has to be extracted explicitly.
        if isinstance(parse_result, tuple):
            parsed = bool(parse_result and parse_result[0])
        else:
            parsed = bool(parse_result)
        if not parsed:
            warnings.warn('XML Parse error')
            return False
        if is_raw:
            root = dom_document.documentElement()
            # Remove top level functions
            _remove_function_nodes(root)
            _clean_dom_tree_recursion(root)
        self._set_document(dom_document)
        return True

    def _set_document(self, dom_document: QDomDocument) -> None:
        """Sets a (preprocessed) DOM document and rebuilds the rows."""
        self.clear()
        self._dom_document = dom_document
        parent_stack: List[QStandardItem] = []
        name_stack: List[str] = []
        self._population_recursion(parent_stack, name_stack,
                                   dom_document.documentElement())

    def _population_recursion(self,
                              parent_stack: List[QStandardItem],
                              name_stack: List[str],
                              dom_element: QDomElement) -> None:
        """Recursion helper for _set_document(): Create the model items.

           parent_stack holds the description items of the enclosing rows,
           name_stack the matching names used to build qualified names.
           Elements that are not code model nodes create no row, but their
           children are still visited."""
        node_type = _item_type(dom_element)
        if node_type:
            name = node_type[1]
            name_stack.append(name)
            qualified_name = '.'.join(name_stack)
            new_row = self._create_row(node_type[0], name, qualified_name)
            if parent_stack:
                parent_stack[-1].appendRow(new_row)
            else:
                self.appendRow(new_row)
            parent_stack.append(new_row[0])

        for child in _child_elements(dom_element):
            self._population_recursion(parent_stack, name_stack, child)
        if node_type:
            # Leave the scope of this element again
            del parent_stack[-1]
            del name_stack[-1]

    def _create_row(self, type: ItemType, name: str,
                    qualified_name: str) -> List[QStandardItem]:
        """Helper for _population_recursion(): Create a row of items.

           Returns [description item, name item]. The description item is
           checkable, initially checked and carries the ItemType as data;
           the name item carries the qualified name as data."""
        flags = Qt.ItemIsSelectable | Qt.ItemIsEnabled
        name_item = QStandardItem(name)
        name_item.setFlags(flags)
        name_item.setData(qualified_name)
        description_item = QStandardItem(ITEM_DESCRIPTION[type])
        description_item.setFlags(flags | Qt.ItemIsUserCheckable)
        description_item.setData(type)
        description_item.setCheckable(True)
        description_item.setCheckState(Qt.Checked)
        return [description_item, name_item]

    @staticmethod
    def _do_recurse(recursion_mode: RecursionMode,
                    desc_item: QStandardItem) -> bool:
        """Helper for _recurse_model(): check whether to recurse down a node
           matching the policy"""
        if recursion_mode == RecursionMode.ALL:
            return True
        checked = desc_item.checkState() == Qt.Checked
        return (recursion_mode == RecursionMode.CHECKED if checked else
                recursion_mode == RecursionMode.UNCHECKED)

    def _model_recursion(self, recursion_mode: RecursionMode,
                         desc_item: QStandardItem,
                         name_item: QStandardItem,
                         func: Callable) -> None:
        """Helper for _recurse_model(): Call func for the row given by
           desc_item/name_item, then recurse down those child rows that
           match the recursion policy."""
        func(desc_item, name_item)
        for i in range(desc_item.rowCount()):
            child_desc_item = desc_item.child(i, 0)
            # The policy has to be evaluated for the child row itself;
            # the parent already passed the check before being visited.
            if CodeModel._do_recurse(recursion_mode, child_desc_item):
                child_name_item = desc_item.child(i, 1)
                self._model_recursion(recursion_mode, child_desc_item,
                                      child_name_item, func)

    def _recurse_model(self, recursion_mode: RecursionMode,
                       func: Callable) -> None:
        """Recurse down model and call func with 2 items (description
           item, name item) for each row matching the recursion policy.
           Rows below a non-matching row are not visited."""
        for i in range(self.rowCount()):
            desc_item = self.item(i, 0)
            if CodeModel._do_recurse(recursion_mode, desc_item):
                name_item = self.item(i, 1)
                self._model_recursion(recursion_mode, desc_item, name_item,
                                      func)

    def selected_classes(self) -> List[str]:
        """Returns the selected classes (for cmake lists): the sorted
           qualified names of all checked namespaces, object and value
           types that are reachable through checked parents only."""
        def collect_selected_classes(result, desc_item, name_item):
            item_type = desc_item.data()
            if item_type in (ItemType.NAMESPACE, ItemType.OBJECT,
                             ItemType.VALUE):
                result.append(name_item.data())
        result: List[str] = []
        func = partial(collect_selected_classes, result)
        self._recurse_model(RecursionMode.CHECKED, func)
        result.sort()
        return result

    def unselected_nodes(self) -> List[str]:
        """Returns the sorted qualified names of the unselected nodes
           (all types), at any depth."""
        def collect_unchecked(names, desc_item, name_item):
            if desc_item.checkState() != Qt.Checked:
                names.append(name_item.data())
        result: List[str] = []
        func = partial(collect_unchecked, result)
        # Walk the whole tree and filter by check state: recursing in
        # UNCHECKED mode would never reach an unchecked row that lives
        # below a checked parent.
        self._recurse_model(RecursionMode.ALL, func)
        result.sort()
        return result

    def set_unselected_nodes(self, unselected: List[str]) -> None:
        """Uncheck all rows whose qualified name is listed in unselected.
           Other rows keep their current check state."""
        def apply_selection(unselected, desc_item, name_item):
            if name_item.data() in unselected:
                desc_item.setCheckState(Qt.Unchecked)
        func = partial(apply_selection, unselected)
        self._recurse_model(RecursionMode.ALL, func)

    @Slot()
    def select_all(self) -> None:
        """Check all rows."""
        def select(desc_item, name_item):
            desc_item.setCheckState(Qt.Checked)
        self._recurse_model(RecursionMode.ALL, select)

    @Slot()
    def unselect_all(self) -> None:
        """Uncheck all rows."""
        def unselect(desc_item, name_item):
            desc_item.setCheckState(Qt.Unchecked)
        self._recurse_model(RecursionMode.ALL, unselect)

    def typesystem(self) -> QDomDocument:
        """Return the Typesystem DOM document with the
           unselected nodes removed. The stored document is not modified;
           an empty document is returned when nothing has been parsed or
           loaded yet."""
        if self._dom_document is None:
            return QDomDocument()
        result = _clone_document(self._dom_document)
        unselected = self.unselected_nodes()
        if unselected:
            stack: List[str] = []
            _remove_unselected_recursion(stack, result.documentElement(),
                                         unselected)
        return result

    def save(self, data: dict):
        """Save to a dict: stores the document as XML text under
           'typesystem' and the unselected names under 'unselected'.
           Nothing is written when the model is empty."""
        if not self.is_empty():
            data['typesystem'] = self._dom_document.toString()
            data['unselected'] = self.unselected_nodes()

    def load(self, data: dict):
        """Load from a dict as written by save(). The model is cleared
           when the dict has no 'typesystem' entry."""
        typesystem = data.get('typesystem')
        if not typesystem:
            self.clear()
            return
        # The stored document was already cleaned when it was first parsed
        self.parse(typesystem, False)
        unselected = data.get('unselected')
        if unselected:
            self.set_unselected_nodes(unselected)
