#!/usr/bin/env python3

# Libervia plugin for Data Form (XEP-0004)
# Copyright (C) 2009-2025 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from typing import Self, TYPE_CHECKING, Type, TypeVar
from typing import Annotated, Literal

from pydantic import BaseModel, Field
from twisted.words.protocols.jabber import jid
from twisted.words.protocols.jabber.xmlstream import XMPPHandler
from twisted.words.xish import domish
from wokkel import data_form, disco
from zope.interface import implementer

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.models.types import DomishElementType, JIDType

if TYPE_CHECKING:
    from libervia.backend.core.main import LiberviaBackend

log = getLogger(__name__)

# Plugin metadata: name, import name, modes, type, implemented protocols, main class,
# handler flag and a translatable description.
PLUGIN_INFO = {
    C.PI_NAME: "Data Forms",
    C.PI_IMPORT_NAME: "XEP-0004",
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_TYPE: "XEP",
    C.PI_PROTOCOLS: ["XEP-0004", "XEP-0315"],
    # Name of the plugin class defined in this module.
    C.PI_MAIN: "XEP_0004",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: _(
        "Implementation of data forms.\n"
        "This plugins extends Wokkel implementation with Pydantic models.\n"
        # XXX: XEP-0315 is currently implemented in "sat_tmp.wokkel.data_form"
        "XEP-0315 (Data Forms XML Element) is also supported here."
    ),
}

# Namespace advertised as a disco feature next to the data forms one.
# NOTE(review): presumably the XEP-0315 (Data Forms XML Element) namespace — confirm.
NS_XML_ELEMENT = "urn:xmpp:xml-element"


class Option(BaseModel):
    """Option of a list field: a value with an optional human readable label."""

    value: str
    label: str | None = None


class BaseField(BaseModel):
    """Attributes shared by all the data form field models."""

    # Name of the field, used to look it up in a form.
    var: str
    # Optional label of the field.
    label: str | None = None
    # Optional description; it is populated and dumped with the ``desc`` alias.
    description: str | None = Field(alias="desc", default=None)
    # Whether the field is flagged as required.
    required: bool = False


class Text(BaseField):
    """Textual field (``text-single`` by default).

    The value is a single string whatever the type is: ``DataForm`` joins the lines of a
    ``text-multi`` field with newlines when parsing, and splits them back when
    serialising.
    """

    type: Literal["text-single", "text-private", "text-multi", "fixed"] = "text-single"
    value: str | None = None


class Boolean(BaseField):
    """Field of type ``boolean``, ``value`` is None when no value is set."""

    type: Literal["boolean"] = "boolean"
    value: bool | None = None


class Hidden(BaseField):
    """Field of type ``hidden`` with an optional string value."""

    type: Literal["hidden"] = "hidden"
    value: str | None = None


class JIDSingle(BaseField):
    """Field of type ``jid-single`` holding at most one JID."""

    type: Literal["jid-single"] = "jid-single"
    value: JIDType | None = None


class JIDMulti(BaseField):
    """Field of type ``jid-multi``.

    Unlike single-valued fields, it exposes a ``values`` list (empty by default) and has
    no ``value`` attribute.
    """

    type: Literal["jid-multi"] = "jid-multi"
    values: list[JIDType] = Field(default_factory=list)


class ListSingle(BaseField):
    """Field of type ``list-single``: one optional value and a list of options."""

    type: Literal["list-single"] = "list-single"
    value: str | None = None
    # Options attached to the field, empty by default.
    options: list[Option] = Field(default_factory=list)


class ListMulti(BaseField):
    """Field of type ``list-multi``.

    Unlike single-valued fields, it exposes a ``values`` list (empty by default) and has
    no ``value`` attribute.
    """

    type: Literal["list-multi"] = "list-multi"
    values: list[str] = Field(default_factory=list)
    # Options attached to the field, empty by default.
    options: list[Option] = Field(default_factory=list)


class XML(BaseField):
    """Field of type ``xml`` whose value, when set, is an XML element.

    NOTE(review): presumably the model for XEP-0315 (Data Forms XML Element) — confirm.
    """

    type: Literal["xml"] = "xml"
    value: DomishElementType | None = None


# Union of all the field models. The ``type`` attribute is the discriminator used by
# Pydantic to select the concrete model.
DataField = Annotated[
    Text | Boolean | Hidden | JIDSingle | JIDMulti | ListSingle | ListMulti | XML,
    Field(discriminator="type"),
]

# Type variable used by ``DataForm.get_field`` to return the exact field model requested.
T = TypeVar("T", bound=DataField)


class DataForm(BaseModel):
    type: Literal["form", "submit", "result", "cancel"] = "form"
    title: str | None = None
    instructions: str | None = None
    namespace: str | None = None
    fields: list[DataField] = Field(default_factory=list)

    def get_values(self) -> dict[str, list[None | str | bool | jid.JID | domish.Element]]:
        """Return a dict mapping from ``var`` to a list of values.

        This is typically used to retrieve submitted values.
        """
        return {
            field.var: (getattr(field, "values", None) or [field.value])  #  type: ignore
            for field in self.fields
            if field.var
        }

    def __getitem__(self, field_var: str) -> DataField | None:
        for field in self.fields:
            if field.var == field_var:
                return field

    def get_field(self, field_var: str, expected_type: Type[T]) -> T:
        """Retrieve a field with static type verification.

        This method solves linter complaints by providing static type guarantees.
            Always use this when you know the expected field type.
            Example::

                text_field = form.get_field("username", Text)
                username = text_field.value

        @param field_var: Field var to retrieve.
        @param expected_type: Expected field type (Text, Boolean, etc.).
        @return: Field instance strictly typed as expected_type.
        @raise KeyError: If field doesn't exist.
        @raise TypeError: If field exists but is wrong type.

        """
        for field in self.fields:
            if field.var == field_var:
                if not isinstance(field, expected_type):
                    raise TypeError(
                        f"Field {field_var!r} is {type(field).__name__} "
                        f"(XEP-0004 type: {getattr(field, 'type', 'unknown')}), "
                        f"expected {expected_type.__name__}"
                    )
                return field

        raise KeyError(f"Field '{field_var}' not found in data form")

    @classmethod
    def from_wokkel_form(
        cls, form: data_form.Form, source_form: Self | None = None
    ) -> Self:
        """Convert Wokkel Form to DataForm.

        @param form: Wokkel Form to convert.
        @param source_form: Source DataForm, used to get field types on submitted form.
            This is the original form the sending entity is expected to have filled, its
            field type are used to correctly set the one from submitted form (if they are
            not specified already).
        @return: Instance of DataForm.

        """
        fields = []

        for field in form.fieldList:
            if field.var == "FORM_TYPE" and field.fieldType == "hidden":
                #  We skip FORM_TYPE field which is handled separately.
                continue

            if (
                field.fieldType is None
                and field.var is not None
                and form.formType == "submit"
            ):
                # If on a submitted form the field types are not specified, we try to
                # retrieve them from source form, if any.
                if source_form is not None:
                    for source_field in source_form.fields:
                        if source_field.var == field.var:
                            field.fieldType = source_field.type
                            break

                # If we have no source field, we default to ``list-multi`` because it has
                # multiple text values, which can then be retrieved with
                # ``get_fields_map``.
                if field.fieldType is None:
                    field.fieldType = "list-multi"

            field_dict = {
                "var": field.var or "",
                "label": field.label,
                "desc": field.desc,
                "required": field.required,
            }

            # Handle options for list fields.
            if field.fieldType in ("list-single", "list-multi") and field.options:
                field_dict["options"] = [
                    Option(value=opt.value, label=opt.label) for opt in field.options
                ]

            try:
                match field.fieldType:
                    case "text-single" | "text-private" | "fixed":
                        fields.append(
                            Text(type=field.fieldType, value=field.value, **field_dict)
                        )
                    case "text-multi":
                        fields.append(
                            Text(
                                type="text-multi",
                                value="\n".join(field.values),
                                **field_dict,
                            )
                        )
                    case "boolean":
                        fields.append(Boolean(value=field.value, **field_dict))
                    case "hidden":
                        fields.append(Hidden(value=field.value, **field_dict))
                    case "jid-single":
                        fields.append(JIDSingle(value=field.value, **field_dict))
                    case "jid-multi":
                        fields.append(JIDMulti(values=field.values, **field_dict))
                    case "list-single":
                        fields.append(ListSingle(value=field.value, **field_dict))
                    case "list-multi":
                        fields.append(ListMulti(values=field.values, **field_dict))
                    case None if field.ext_type == "xml":
                        fields.append(XML(value=field.value, **field_dict))
                    case _:
                        # Fallback to text-single for unknown types.
                        fields.append(
                            Text(
                                value="\n".join(field.values) if field.values else "",
                                **field_dict,
                            )
                        )
            except Exception:
                # Fallback for any validation issues.
                fields.append(
                    Text(
                        value="\n".join(field.values) if field.values else "",
                        **field_dict,
                    )
                )

        return cls(
            type=form.formType,
            title=form.title,
            instructions="\n".join(form.instructions) if form.instructions else None,
            namespace=form.formNamespace,
            fields=fields,
        )

    def to_wokkel_form(self) -> data_form.Form:
        """Convert DataForm to Wokkel Form."""
        form = data_form.Form(
            formType=self.type,
            title=self.title,
            instructions=self.instructions.splitlines() if self.instructions else None,
            formNamespace=self.namespace,
        )

        for field in self.fields:
            kwargs = field.model_dump(by_alias=True)
            # We don't want to alias each indivial type with "fieldType", so we change the
            # name of the key here.

            kwargs["fieldType"] = kwargs.pop("type")
            if kwargs["fieldType"] == "text-multi":
                if value := kwargs.pop("value"):
                    kwargs["values"] = value.splitlines()
            if "options" in kwargs:
                kwargs["options"] = [
                    data_form.Option(value=opt["value"], label=opt["label"])
                    for opt in kwargs["options"]
                ]

            form.addField(data_form.Field(**kwargs))

        return form

    @classmethod
    def from_element(
        cls, element: domish.Element, source_form: Self | None = None
    ) -> Self:
        """Parse a data form to create an instance of DataForm.

        If the element itself is not a data form, a data form will be looked after in its
        descendant.
        @param element: Element to parse, it must be a data form, or have a child which is
            a data form.
        @param source_form: Original form for submitted forms. See [from_wokkel_form].
        @return: Instance of DataForm.
        @raise exceptions.NotFound: No data form found in ``element`` or one of its
            descendants.
        """
        if (element.uri, element.name) != ((data_form.NS_X_DATA, "x")):
            try:
                element = next(element.elements(data_form.NS_X_DATA, "x"))
            except StopIteration:
                raise exceptions.NotFound(
                    f"No data form found in {element} or its descendants."
                )
        form = data_form.Form.fromElement(element)
        return cls.from_wokkel_form(form, source_form=source_form)

    def to_element(self) -> domish.Element:
        """Generate the <x> element corresponding to this form."""
        return self.to_wokkel_form().toElement()


class XEP_0004:
    """Main class of the Data Forms plugin."""

    def __init__(self, host: "LiberviaBackend") -> None:
        """Log the initialisation and register the data forms namespace.

        @param host: Backend instance, the namespace is registered on it under the
            ``x-data`` name.
        """
        log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization")
        host.register_namespace("x-data", data_form.NS_X_DATA)
        self.host = host

    def get_handler(self, client) -> "XEP_0004_handler":
        """Create a new handler bound to this plugin.

        @param client: Not used here.
        @return: A new handler instance.
        """
        handler = XEP_0004_handler(self)
        return handler


@implementer(disco.IDisco)
class XEP_0004_handler(XMPPHandler):
    """Service discovery handler advertising data forms support."""

    def __init__(self, plugin_parent: XEP_0004) -> None:
        """Initialise the handler.

        @param plugin_parent: Plugin instance which has created this handler.
        """
        # The base initialiser sets ``parent`` and ``xmlstream`` to None. Without it,
        # those attributes don't exist until the handler is attached and connected.
        super().__init__()
        self.plugin_parent = plugin_parent

    def getDiscoInfo(
        self, requestor: jid.JID, target: jid.JID, nodeIdentifier: str = ""
    ) -> list[disco.DiscoFeature]:
        """Get the disco features advertised by this plugin.

        The same features are returned whatever the arguments are.

        @param requestor: JID of the requesting entity (unused)
        @param target: JID of the target entity (unused)
        @param nodeIdentifier: optional node identifier (unused)
        @return: features for the data forms and the XML element namespaces
        """
        return [
            disco.DiscoFeature(data_form.NS_X_DATA),
            disco.DiscoFeature(NS_XML_ELEMENT),
        ]

    def getDiscoItems(
        self, requestor: jid.JID, target: jid.JID, nodeIdentifier: str = ""
    ) -> list[disco.DiscoItem]:
        """Get disco items: this plugin doesn't expose any.

        @param requestor: JID of the requesting entity (unused)
        @param target: JID of the target entity (unused)
        @param nodeIdentifier: optional node identifier (unused)
        @return: always an empty list
        """
        return []
