# Copyright 2025, Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Asynchronous bi-directional streaming RPC helpers."""

import asyncio
import logging
from typing import Callable, Optional, Union

from grpc import aio

from google.api_core import exceptions
from google.api_core.bidi_base import BidiRpcBase

from google.protobuf.message import Message as ProtobufMessage


# Module logger; the request generator below emits debug messages through it.
_LOGGER = logging.getLogger(__name__)


class _AsyncRequestQueueGenerator:
    """Async iterator that feeds requests from an ``asyncio.Queue`` to gRPC.

    Hand an instance to a request-streaming (or bidirectional) RPC as its
    request iterator when the set of requests is open-ended. Requests are
    pulled off the queue one at a time as they arrive, instead of being
    supplied all at once up front. Putting ``None`` on the queue ends the
    request stream.

    Example::

        requests = _AsyncRequestQueueGenerator(q)
        call = await stub.StreamingRequest(requests)
        requests.call = call

        async for response in call:
            print(response)
            await q.put(...)

    Args:
        queue (asyncio.Queue): The request queue.
        initial_request (Union[ProtobufMessage,
                Callable[[], ProtobufMessage]]): A request yielded before
            anything taken from the queue; if it is callable, it is invoked to
            produce the request. Keeping it separate from the queue makes it
            easy to restart streams that need an initial configuration
            request.
    """

    def __init__(
        self,
        queue: asyncio.Queue,
        initial_request: Optional[
            Union[ProtobufMessage, Callable[[], ProtobufMessage]]
        ] = None,
    ) -> None:
        self._queue = queue
        self._initial_request = initial_request
        # The call this generator feeds. The owner assigns it after the RPC
        # has been started, so it may still be None when iteration begins.
        self.call: Optional[aio.Call] = None

    def _is_active(self) -> bool:
        """Tell whether the associated call may still accept requests.

        Iteration can start *before* ``self.call`` has been assigned
        (https://github.com/googleapis/python-api-core/issues/560), so a
        missing call counts as active. Otherwise the call is active until it
        reports ``done()``.
        """
        if self.call is None:
            return True
        return not self.call.done()

    async def __aiter__(self):
        # Pulling from an asyncio.Queue lets the caller decide when each
        # request is sent, rather than providing every request up front; gRPC
        # simply waits until the next item shows up.
        #
        # It also makes retries possible without swapping queues. An item
        # dequeued after the call has finished is put back on the queue
        # (yielding it would make gRPC drop it) and iteration stops. The
        # consequence is that message order is not guaranteed; a priority
        # queue could be used if ordering matters.
        first = self._initial_request
        if first is not None:
            yield first() if callable(first) else first

        while True:
            request = await self._queue.get()

            # ``None`` is the caller's explicit signal to end the requests.
            if request is None:
                _LOGGER.debug("Cleanly exiting request generator.")
                return

            if self._is_active():
                yield request
                continue

            # The call is closed: hand the item back so that a later call
            # can consume it, then stop.
            await self._queue.put(request)
            _LOGGER.debug(
                "Inactive call, replacing item on queue and exiting "
                "request generator."
            )
            return


class AsyncBidiRpc(BidiRpcBase):
    """A helper for consuming an async bi-directional streaming RPC.

    gRPC exposes a bidirectional stream as a request iterator plus a response
    iterator. This helper maps that onto a socket-like :func:`send` /
    :func:`recv` pair. That shape suits long-running or asymmetric streams,
    where requests and responses are not paired one-to-one.

    Example::

        initial_request = example_pb2.StreamingRpcRequest(
            setting='example')
        rpc = AsyncBidiRpc(
            stub.StreamingRpc,
            initial_request=initial_request,
            metadata=[('name', 'value')]
        )

        await rpc.open()

        while rpc.is_active:
            print(await rpc.recv())
            await rpc.send(example_pb2.StreamingRpcRequest(
                data='example'))

        await rpc.close()

    Errors are *not* retried; a failed stream stays failed.

    Args:
        start_rpc (grpc.aio.StreamStreamMultiCallable): The gRPC method used to
            start the RPC. :meth:`open` awaits the result of calling it, so it
            must return an awaitable that resolves to the call.
        initial_request (Union[ProtobufMessage,
                Callable[[], ProtobufMessage]]): The initial request to
            yield. This is useful if an initial request is needed to start the
            stream.
        metadata (Sequence[Tuple[str, str]]): RPC metadata to include in
            the request.
    """

    def _create_queue(self) -> asyncio.Queue:
        """Build the (unbounded) ``asyncio.Queue`` that holds outgoing requests."""
        return asyncio.Queue()

    async def open(self) -> None:
        """Opens the stream.

        Raises:
            ValueError: If the stream is already active.
            google.api_core.exceptions.GoogleAPICallError: If starting the RPC
                fails. The original call (``exc.response``) is first passed to
                ``self._on_call_done``, and then the exception is re-raised.
        """
        if self.is_active:
            raise ValueError("Cannot open an already open stream.")

        requests = _AsyncRequestQueueGenerator(
            self._request_queue, initial_request=self._initial_request
        )
        try:
            call = await self._start_rpc(requests, metadata=self._rpc_metadata)
        except exceptions.GoogleAPICallError as exc:
            # The mapped exception keeps the original ``grpc.aio.AioRpcError``
            # (usually also a ``grpc.aio.Call``) on its ``response`` property.
            self._on_call_done(exc.response)
            raise

        # Tell the generator which call it feeds, so it can notice when that
        # call has finished.
        requests.call = call

        # TODO: api_core should expose the future interface for wrapped
        # callables as well.
        done_source = (
            call._wrapped if hasattr(call, "_wrapped") else call  # pragma: NO COVER
        )
        done_source.add_done_callback(self._on_call_done)

        self._request_generator = requests
        self.call = call

    async def close(self) -> None:
        """Closes the stream.

        If the stream was never opened, this is a no-op. Otherwise it:

        * queues ``None`` so the request generator stops,
        * cancels the call, and
        * drops the request generator, the initial request and the
          registered callbacks.

        ``self.call`` is deliberately left in place so that later
        ``send``/``recv`` calls can still surface the call's error.
        """
        if self.call is None:
            return

        await self._request_queue.put(None)
        self.call.cancel()

        self._request_generator = None
        self._initial_request = None
        self._callbacks = []

    async def send(self, request: ProtobufMessage) -> None:
        """Queue a message to be sent on the stream.

        If the underlying RPC has been closed, this will raise.

        Args:
            request (ProtobufMessage): The request to send.

        Raises:
            ValueError: If the stream has never been opened.
        """
        if self.call is None:
            raise ValueError("Cannot send on an RPC stream that has never been opened.")

        if self.call.done():
            # The call is finished, so reading from it is relied on to raise
            # its error instead of queueing a request nobody will send.
            # NOTE(review): for a call that ended without error, read() may
            # return EOF rather than raise — confirm the intended behavior.
            await self.call.read()
            return

        await self._request_queue.put(request)

    async def recv(self) -> ProtobufMessage:
        """Wait for a message to be returned from the stream.

        If the underlying RPC has been closed, this will raise.

        Returns:
            ProtobufMessage: The received message.

        Raises:
            ValueError: If the stream has never been opened.
        """
        call = self.call
        if call is None:
            raise ValueError("Cannot recv on an RPC stream that has never been opened.")
        return await call.read()

    @property
    def is_active(self) -> bool:
        """Whether a call exists and has not yet completed."""
        call = self.call
        if call is None:
            return False
        return not call.done()
