/*******************************************************************************
 *
 * MIT License
 *
 * Copyright 2025 AMD ROCm(TM) Software
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#pragma once

#include "OrderMultiplyNodes.hpp"

#include <optional>
#include <unordered_map>
#include <vector>

#include <rocRoller/KernelGraph/ControlGraph/ControlFlowRWTracer.hpp>

namespace rocRoller
{
    namespace KernelGraph
    {
        namespace OrderMultiplyNodesDetail
        {
            /**
             * Returns multiply nodes in graph grouped by their body parent.
             *
             * @param graph Kernel graph to inspect. Taken by const reference, so it is not
             *              modified by this call.
             * @return A map whose mapped values are the groups of multiply node ids.
             *         NOTE(review): the integer key is presumably the id of the body
             *         parent shared by every node in the mapped vector -- confirm against
             *         the definition. The container is a `std::unordered_map`, so the
             *         iteration order over groups is unspecified.
             */
            std::unordered_map<int, std::vector<int>>
                getGroupedMultiplyNodes(KernelGraph const& graph);

            /**
             * Creates a sub-graph of the given nodes.
             *
             * The sub-graph is created by adding the given nodes to a new control graph,
             * and then adding sequence edges between the nodes based on the order of the
             * nodes in the original control graph.
             *
             * @param graph Kernel graph that supplies the original control graph. Taken by
             *              const reference, so it is not modified by this call.
             * @param nodes Ids of the nodes to place in the new control graph.
             * @return The newly built control graph, returned by value.
             */
            ControlGraph::ControlGraph createSubGraph(KernelGraph const&      graph,
                                                      std::vector<int> const& nodes);

            /**
             * Sorts `nodes` according to existing order and according to BestNodeOrder.
             *
             * `nodes` must be a collection of nodes directly within the same body-parent in
             * `graph.control`.
             *
             * @param graph Kernel graph that `nodes` belong to. Taken by const reference,
             *              so it is not modified by this call.
             * @param nodes Node ids to sort. This is an in/out parameter: the vector is
             *              reordered in place and nothing is returned.
             */
            void orderNodes(KernelGraph const& graph, std::vector<int>& nodes);

            /**
             * Comparator for ordering multiply nodes.
             *
             * This is used to order the multiply nodes in each group.
             *
             * Note that this is NOT guaranteed to be a strict weak ordering and should not be
             * used with `std::sort`.
             *
             * The order is determined by the following criteria:
             * 1. If available, use downstream memory nodes, to enable memory nodes
             *    to be scheduled earlier in some kernels.
             * 2. Otherwise if available, use last upstream tag dependencies, to prioritize
             *    multiplies that will have lower waitcount values.
             * 3. Otherwise use integer comparison as a last resort.
             *
             * Lifetime: the object holds a `KernelGraph const&` member (presumably bound to
             * the constructor argument), so that graph must outlive the comparator.
             */
            struct BestNodeOrder
            {
                /**
                 * Builds a comparator for the nodes of `graph`.
                 *
                 * Declared `explicit` so that a `KernelGraph` is never silently converted
                 * into a comparator.
                 *
                 * @param graph Kernel graph whose control nodes will be compared.
                 */
                explicit BestNodeOrder(KernelGraph const& graph);

                /**
                 * Compares the control nodes `a` and `b` using the criteria listed on the
                 * struct.
                 *
                 * NOTE(review): presumably returns true when `a` should be placed before
                 * `b` -- confirm against the definition.
                 */
                bool operator()(int a, int b) const;

                /**
                 * Looks for memory nodes downstream of `node` in a breadth-first search of the
                 * control graph. Returns the first memory node found, or `std::nullopt` if none
                 * exists.
                 */
                std::optional<int> downstreamMemoryNode(int node) const;

                /**
                 * Looks for memory nodes downstream in a breadth-first search of the control
                 * graph from `a` and `b`. If these are found and are different from each other,
                 * it will return an order based on the control graph order of those memory nodes.
                 *
                 * For kernels that are not double buffered, this can be used to prioritize
                 * multiply nodes that will enable memory nodes to be scheduled earlier in the
                 * kernel, overlapping with other multiply nodes.
                 *
                 * For double-buffered kernels, this will generally be the same memory node for
                 * both `a` and `b`, so this will return `std::nullopt`.
                 */
                std::optional<bool> orderByDownstreamMemoryNodes(int a, int b) const;

                /**
                 * Looks for data flow tags that are read by `node`. Then looks for control
                 * nodes that write to those tags, and are before `node`. Returns those nodes in
                 * reverse topological order (i.e. from latest to earliest).
                 *
                 * The result is returned by const reference.
                 * NOTE(review): the reference presumably points into the
                 * `m_reversedTagDependencies` cache and so stays valid only while this
                 * object is alive -- confirm against the definition.
                 */
                std::vector<int> const& reversedTagDependencies(int node) const;

                /**
                 * Calls `reversedTagDependencies()` on `a` and `b`. Then will do an
                 * element-wise comparison using `compareNodes()` on those results.
                 */
                std::optional<bool> orderByLastTagDependencies(int a, int b) const;

            private:
                /**
                 * Helper function. Returns the existing order of the two nodes if it is
                 * defined, otherwise returns std::nullopt.
                 */
                std::optional<bool> existingOrder(int a, int b) const;

                // Non-owning reference to the graph being ordered; see the lifetime note
                // on the struct.
                KernelGraph const& m_graph;

                // NOTE(review): presumably used to find which control nodes read/write
                // the data flow tags needed by reversedTagDependencies() -- confirm.
                ControlFlowRWTracer m_tracer;

                // Both maps are `mutable`, so the const member functions above may modify
                // them. NOTE(review): they look like per-node memoization caches for
                // downstreamMemoryNode() and reversedTagDependencies(); if so, concurrent
                // calls on one instance are presumably not safe -- confirm.
                mutable std::unordered_map<int, std::optional<int>> m_downstreamMemoryNodes;
                mutable std::unordered_map<int, std::vector<int>>   m_reversedTagDependencies;
            };
        }
    }
}
