/*******************************************************************************
 *
 * MIT License
 *
 * Copyright 2024-2025 AMD ROCm(TM) Software
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#pragma once

#include <algorithm>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <omp.h>

namespace DGen
{
    /**
     * @brief Multiply together all elements of a vector.
     *
     * Each element is converted to size_t before being multiplied, and the
     * accumulation itself is carried out in size_t (so it wraps on overflow).
     *
     * @param x The values to multiply
     * @return The product of all elements, or 1 when x is empty
     */
    template <typename T>
    inline size_t product(std::vector<T> const& x)
    {
        // A lambda is used rather than std::multiplies so that this helper only
        // needs <numeric>; std::multiplies is declared in <functional>.
        return std::accumulate(x.begin(), x.end(), size_t(1), [](size_t acc, T const& value) {
            return acc * static_cast<size_t>(value);
        });
    }

    /**
     * @brief Build the strides of a densely packed layout in which dimension 0
     *        varies fastest (column-major).
     *
     * The stride of dimension i is the product of the sizes of all dimensions
     * before it, so the first stride is always 1.
     *
     * @param sizes The extent of each dimension
     * @return One stride per dimension; empty when sizes is empty
     */
    inline std::vector<size_t> computeStrides(std::vector<size_t> const& sizes)
    {
        std::vector<size_t> result;
        result.reserve(sizes.size());

        // Running product of the extents seen so far.
        size_t running = 1;
        for(auto const extent : sizes)
        {
            result.push_back(running);
            running *= extent;
        }

        return result;
    }

    /**
     * @brief Build strides for a layout whose dimensions are packed in the
     *        order given by dimOrder.
     *
     * dimOrder is walked front to back: the first listed dimension receives
     * stride 1, and each following one receives the product of the sizes of
     * the dimensions listed before it. Dimensions that never appear in
     * dimOrder keep a stride of 0.
     *
     * @param sizes The extent of each dimension
     * @param dimOrder Dimension indices, from fastest-varying to slowest
     * @return One stride per entry of sizes
     * @throws std::out_of_range if dimOrder holds an index >= sizes.size()
     */
    inline std::vector<size_t> computeShuffledStrides(std::vector<size_t> const& sizes,
                                                       std::vector<size_t> const& dimOrder)
    {
        std::vector<size_t> result(sizes.size(), 0);

        size_t running = 1;
        for(size_t pos = 0; pos < dimOrder.size(); ++pos)
        {
            size_t const dim = dimOrder[pos];

            // at() gives a bounds-checked access for indices taken from dimOrder.
            result.at(dim) = running;
            running *= sizes.at(dim);
        }

        return result;
    }

    /**
     * @brief Shuffle data according to dimension reordering
     *
     * Every N-D coordinate of the index space described by sizes is visited
     * once; the element found at that coordinate's offset under srcStrides is
     * copied to the coordinate's offset under dstStrides:
     * - input is arranged according to srcStrides
     * - output is arranged according to dstStrides
     * - both have the same dimension sizes
     *
     * The copy loop is annotated with an OpenMP parallel-for, so dstStrides
     * must map distinct coordinates to distinct offsets; otherwise several
     * iterations write the same output element.
     *
     * @param input The source data; must hold exactly product(sizes) elements
     * @param sizes The extent of each dimension (at least 2 dimensions)
     * @param dstStrides Per-dimension strides of the output layout
     * @param srcStrides Per-dimension strides of the input layout
     * @return A vector of the same length as input, holding the shuffled data
     * @throws std::runtime_error if the dimension counts disagree, fewer than
     *         2 dimensions are given, input has the wrong length, or the
     *         strides would address an element outside the buffers
     */
    template <typename T>
    inline std::vector<T> shuffleDims(std::vector<T> const&      input,
                                      std::vector<size_t> const& sizes,
                                      std::vector<size_t> const& dstStrides,
                                      std::vector<size_t> const& srcStrides)
    {
        size_t const numDims = sizes.size();

        if(numDims != dstStrides.size() || numDims != srcStrides.size())
            throw std::runtime_error("shuffleDims: size/stride dimension mismatch");

        if(numDims < 2)
            throw std::runtime_error("shuffleDims: need at least 2 dimensions");

        size_t const totalElements = product(sizes);
        if(input.size() != totalElements)
        {
            std::ostringstream msg;
            msg << "shuffleDims: input size " << input.size() << " doesn't match expected "
                << totalElements;
            throw std::runtime_error(msg.str());
        }

        std::vector<T> output(input.size());

        // Nothing to copy; this also guarantees every size is >= 1 below.
        if(totalElements == 0)
            return output;

        // The largest offsets are reached at the last coordinate (size - 1 in
        // every dimension). Checking them once, outside the parallel region,
        // keeps the copy loop free of out-of-bounds accesses and of exceptions.
        size_t maxSrcIdx = 0;
        size_t maxDstIdx = 0;
        for(size_t i = 0; i < numDims; ++i)
        {
            maxSrcIdx += (sizes[i] - 1) * srcStrides[i];
            maxDstIdx += (sizes[i] - 1) * dstStrides[i];
        }
        if(maxSrcIdx >= totalElements || maxDstIdx >= totalElements)
            throw std::runtime_error("shuffleDims: strides address elements outside the buffer");

#pragma omp parallel for
        for(size_t coordNum = 0; coordNum < totalElements; ++coordNum)
        {
            // Peel the N-D coordinate off coordNum one dimension at a time
            // (dimension 0 first) and accumulate both offsets in the same pass,
            // so no per-element coordinate buffer has to be allocated.
            size_t remaining = coordNum;
            size_t srcIdx    = 0;
            size_t dstIdx    = 0;
            for(size_t i = 0; i < numDims; ++i)
            {
                size_t const coord = remaining % sizes[i];
                remaining /= sizes[i];

                srcIdx += coord * srcStrides[i];
                dstIdx += coord * dstStrides[i];
            }

            output[dstIdx] = input[srcIdx];
        }

        return output;
    }

    /**
     * @brief Pre-swizzle and optionally pre-tile the input.
     *
     * This function rearranges tensor data according to swizzle and tile configurations.
     * The 2-D input is viewed as a higher-rank tensor (srcSizes) by splitting each
     * dimension into factors, and those factors are then permuted (dimOrder) with
     * shuffleDims. The incoming data is expected with dimension 0 (sizes[0]) being the
     * fastest (smallest stride).
     *
     * Three modes are supported:
     * - swizzle only (preTileSize empty),
     * - tile only (preSwizzleSize empty),
     * - swizzle and tile (both non-empty).
     *
     * @param input The input data vector; must hold sizes[0] * sizes[1] elements
     * @param sizes The dimension sizes {size0, size1}
     * @param preSwizzleSize The swizzle configuration {tileMN, tileK, subTileK}, or empty
     * @param preTileSize The pre-tile configuration {tileSize0, tileSize1}, or empty
     * @return The pre-swizzled/pre-tiled data
     * @throws std::runtime_error if both configurations are empty, if a configuration
     *         has too few entries or contains an unsupported or zero value, or if
     *         the tile factors do not divide sizes exactly
     */
    template <typename T>
    inline std::vector<T> preSwizzle(std::vector<T> const&      input,
                                     std::vector<size_t> const& sizes,
                                     std::vector<size_t> const& preSwizzleSize,
                                     std::vector<size_t> const& preTileSize)
    {
        bool const doSwizzle = !preSwizzleSize.empty();
        bool const doTile    = !preTileSize.empty();

        if(doSwizzle && preSwizzleSize.size() != 3)
        {
            std::ostringstream msg;
            msg << "preSwizzle: preSwizzleSize must have 3 elements, got "
                << preSwizzleSize.size();
            throw std::runtime_error(msg.str());
        }

        if(sizes.size() != 2)
        {
            std::ostringstream msg;
            msg << "preSwizzle: Batch dimension not yet supported. sizes.size()=" << sizes.size();
            throw std::runtime_error(msg.str());
        }

        size_t const totalElements = product(sizes);
        if(totalElements != input.size())
        {
            std::ostringstream msg;
            msg << "preSwizzle: input size " << input.size() << " doesn't match sizes product "
                << totalElements;
            throw std::runtime_error(msg.str());
        }

        // With neither configuration there is no layout to derive; reject the call
        // instead of indexing into an empty preSwizzleSize.
        if(!doSwizzle && !doTile)
            throw std::runtime_error(
                "preSwizzle: at least one of preSwizzleSize or preTileSize must be non-empty");

        // Both tile entries are read below, so they must exist.
        if(doTile && preTileSize.size() < 2)
        {
            std::ostringstream msg;
            msg << "preSwizzle: preTileSize must have at least 2 elements, got "
                << preTileSize.size();
            throw std::runtime_error(msg.str());
        }

        // Factors into which one tileK x tileMN swizzle tile is split.
        struct SwizzleFactors
        {
            size_t tileMN;
            size_t tileK;
            size_t subTileK;
            size_t nLanesPerSIMD;
            size_t nSIMDBlock;
            size_t nVGPRIndex;
            size_t nVGPRBlock;
            size_t nSIMDIndexBlock;
            size_t nSIMDIndexIndex;
        };

        // Validates the raw swizzle configuration and derives the tile factors.
        auto const computeSwizzleFactors = [&preSwizzleSize]() {
            SwizzleFactors f{};
            f.tileMN   = preSwizzleSize[0];
            f.tileK    = preSwizzleSize[1];
            f.subTileK = preSwizzleSize[2];

            if(f.tileMN != 64 && f.tileMN != 32)
            {
                std::ostringstream msg;
                msg << "preSwizzle: tileMN must be 32 or 64, got " << f.tileMN;
                throw std::runtime_error(msg.str());
            }

            if(f.tileK % 4 != 0)
            {
                std::ostringstream msg;
                msg << "preSwizzle: tileK must be a multiple of 4, got " << f.tileK;
                throw std::runtime_error(msg.str());
            }

            // Both values are used as divisors further down.
            if(f.tileK == 0)
                throw std::runtime_error("preSwizzle: tileK must be non-zero");

            if(f.subTileK == 0)
                throw std::runtime_error("preSwizzle: subTileK must be non-zero");

            size_t const nSIMDsPerWave = 4;
            f.nLanesPerSIMD            = 16;

            // tileMN is 32 or 64 here, so nSIMDIndex is 2 or 4 and never zero.
            size_t const nSIMDIndex = f.tileMN / f.nLanesPerSIMD;
            f.nSIMDBlock            = nSIMDsPerWave / nSIMDIndex;
            f.nVGPRIndex            = std::min(nSIMDIndex, f.subTileK);
            f.nVGPRBlock            = f.tileK / f.nSIMDBlock / f.nVGPRIndex;
            f.nSIMDIndexBlock       = f.nVGPRIndex;
            f.nSIMDIndexIndex       = nSIMDIndex / f.nSIMDIndexBlock;
            return f;
        };

        // Confirms that the integer divisions above factor the tile exactly.
        auto const checkSwizzleFactors = [](SwizzleFactors const& f) {
            if(f.nVGPRIndex * f.nVGPRBlock * f.nSIMDBlock != f.tileK)
                throw std::runtime_error(
                    "preSwizzle: nVGPRIndex * nVGPRBlock * nSIMDBlock != tileK");

            if(f.nLanesPerSIMD * f.nSIMDIndexIndex * f.nSIMDIndexBlock != f.tileMN)
                throw std::runtime_error(
                    "preSwizzle: nLanesPerSIMD * nSIMDIndexIndex * nSIMDIndexBlock != tileMN");
        };

        std::vector<size_t> srcSizes;
        std::vector<size_t> dimOrder;

        if(doSwizzle && !doTile)
        {
            auto const f = computeSwizzleFactors();
            checkSwizzleFactors(f);

            // Dimensions 0-3 factor sizes[0]; dimensions 4-7 factor sizes[1].
            srcSizes = {f.nVGPRIndex,
                        f.nVGPRBlock,
                        f.nSIMDBlock,
                        sizes[0] / f.tileK,
                        f.nLanesPerSIMD,
                        f.nSIMDIndexIndex,
                        f.nSIMDIndexBlock,
                        sizes[1] / f.tileMN};

            if(f.tileMN == 64)
            {
                // Pre swizzle: swap nSIMDIndexBlock (6) and nVGPRIndex (0)
                dimOrder = {6, 1, 2, 3, 4, 5, 0, 7};
            }
            else if(f.tileMN == 32 && f.subTileK == 4)
            {
                // Pre swizzle: swap nSIMDIndexBlock (6) and nVGPRIndex (0)
                //              swap nSIMDBlock (2) and nVGPRBlock (1)
                dimOrder = {6, 2, 1, 3, 4, 5, 0, 7};
            }
            else if(f.tileMN == 32 && f.subTileK == 2)
            {
                // Pre swizzle: rotate nVGPRIndex (0), nVGPRBlock (1), nSIMDBlock (2)
                dimOrder = {1, 2, 0, 3, 4, 5, 6, 7};
            }
            // Any other subTileK leaves dimOrder empty and is rejected below.
        }
        else if(!doSwizzle)
        {
            // Tile only (doTile is guaranteed by the check above).
            size_t const tileSize0 = preTileSize[0];
            size_t const tileSize1 = preTileSize[1];

            // Both values are used as divisors.
            if(tileSize0 == 0 || tileSize1 == 0)
                throw std::runtime_error("preSwizzle: preTileSize entries must be non-zero");

            srcSizes = {tileSize0, sizes[0] / tileSize0, tileSize1, sizes[1] / tileSize1};

            // Pre-tiling: 1 and 3 are pushed to the back (they become the slowest)
            dimOrder = {0, 2, 1, 3};
        }
        else
        {
            // Swizzle inside each pre-tile.
            auto const f = computeSwizzleFactors();

            size_t const ptTileSizeK  = preTileSize[0];
            size_t const ptTileSizeMN = preTileSize[1];

            // A pre-tile must hold at least one swizzle tile in each direction.
            // These checks also reject zero pre-tile sizes before they are used
            // as divisors.
            if(ptTileSizeK / f.tileK == 0)
            {
                std::ostringstream msg;
                msg << "preSwizzle: ptTileSizeK / tileK == 0, ptTileSizeK=" << ptTileSizeK
                    << ", tileK=" << f.tileK;
                throw std::runtime_error(msg.str());
            }

            if(ptTileSizeMN / f.tileMN == 0)
            {
                std::ostringstream msg;
                msg << "preSwizzle: ptTileSizeMN / tileMN == 0, ptTileSizeMN=" << ptTileSizeMN
                    << ", tileMN=" << f.tileMN;
                throw std::runtime_error(msg.str());
            }

            checkSwizzleFactors(f);

            // Dimensions 0-4 factor sizes[0]; dimensions 5-9 factor sizes[1].
            srcSizes = {f.nVGPRIndex,
                        f.nVGPRBlock,
                        f.nSIMDBlock,
                        ptTileSizeK / f.tileK,
                        sizes[0] / ptTileSizeK,
                        f.nLanesPerSIMD,
                        f.nSIMDIndexIndex,
                        f.nSIMDIndexBlock,
                        ptTileSizeMN / f.tileMN,
                        sizes[1] / ptTileSizeMN};

            if(f.tileMN == 64)
            {
                // Pre swizzle: swap nSIMDIndexBlock (7) and nVGPRIndex (0)
                // Pre tile: push workgroup tiles (4 and 9) to the end
                dimOrder = {7, 1, 2, 3, 5, 6, 0, 8, 4, 9};
            }
            else if(f.tileMN == 32 && f.subTileK == 4)
            {
                // Pre swizzle: swap nSIMDIndexBlock (7) and nVGPRIndex (0)
                //              swap nSIMDBlock (2) and nVGPRBlock (1)
                // Pre tile: push workgroup tiles (4 and 9) to the end
                dimOrder = {7, 2, 1, 3, 5, 6, 0, 8, 4, 9};
            }
            else if(f.tileMN == 32 && f.subTileK == 2)
            {
                // Pre swizzle: rotate nVGPRIndex (0), nVGPRBlock (1), nSIMDBlock (2)
                // Pre tile: push workgroup tiles (4 and 9) to the end
                dimOrder = {1, 2, 0, 3, 5, 6, 7, 8, 4, 9};
            }
            // Any other subTileK leaves dimOrder empty and is rejected below.
        }

        // The factors are built with truncating divisions; a product mismatch means
        // some tile size does not divide the corresponding dimension exactly.
        size_t const srcElements = product(srcSizes);
        if(srcElements != totalElements)
        {
            std::ostringstream msg;
            msg << "PreSwizzle size mismatch: product(srcSizes)=" << srcElements
                << " != product(sizes)=" << totalElements;
            throw std::runtime_error(msg.str());
        }

        if(srcSizes.empty())
            throw std::runtime_error("PreSwizzle source size not populated.");

        if(dimOrder.empty())
            throw std::runtime_error("PreSwizzle permutation order not populated.");

        auto const srcStrides = computeStrides(srcSizes);
        auto const dstStrides = computeShuffledStrides(srcSizes, dimOrder);

        return shuffleDims(input, srcSizes, dstStrides, srcStrides);
    }

} // namespace DGen

