//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

// Modifications Copyright (c) 2024-2025 Advanced Micro Devices, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef _CUDA_STREAM_REF
#define _CUDA_STREAM_REF

/*
    stream_ref synopsis
namespace cuda {
class stream_ref {
    using value_type = hipStream_t;

    stream_ref() = default;
    stream_ref(hipStream_t stream_) noexcept : stream(stream_) {}

    stream_ref(int) = delete;
    stream_ref(nullptr_t) = delete;

    [[nodiscard]] value_type get() const noexcept;

    void wait() const;

    [[nodiscard]] bool ready() const;

    [[nodiscard]] int priority() const;

    [[nodiscard]] friend bool operator==(stream_ref, stream_ref);
    [[nodiscard]] friend bool operator!=(stream_ref, stream_ref);

private:
  hipStream_t stream = 0; // exposition only
};
}  // cuda
*/

// hip_runtime.h needs to come first
#include <hip/hip_runtime.h>
#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/__exception/cuda_error.h>
#include <cuda/std/cstddef>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

/**
 * \brief A non-owning wrapper for a `hipStream_t`.
 */
class stream_ref
{
protected:
  ::hipStream_t __stream{0};

public:
  using value_type = ::hipStream_t;

  /**
   * \brief Constructs a `stream_ref` of the "default" CUDA stream.
   *
   * For behavior of the default stream,
   * \see
   * https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
   *
   */
  _CCCL_HIDE_FROM_ABI stream_ref() = default;

  /**
   * \brief Constructs a `stream_ref` from a `hipStream_t` handle.
   *
   * This constructor provides implicit conversion from `hipStream_t`.
   *
   * \note: It is the callers responsibilty to ensure the `stream_ref` does not
   * outlive the stream identified by the `hipStream_t` handle.
   *
   */
  constexpr stream_ref(value_type __stream_) noexcept
      : __stream{__stream_}
  {}

  /// Disallow construction from an `int`, e.g., `0`.
  stream_ref(int) = delete;

  /// Disallow construction from `nullptr`.
  stream_ref(_CUDA_VSTD::nullptr_t) = delete;

  /**
   * \brief Compares two `stream_ref`s for equality
   *
   * \note Allows comparison with `hipStream_t` due to implicit conversion to
   * `stream_ref`.
   *
   * \param lhs The first `stream_ref` to compare
   * \param rhs The second `stream_ref` to compare
   * \return true if equal, false if unequal
   */
  _CCCL_NODISCARD_FRIEND constexpr bool operator==(const stream_ref& __lhs, const stream_ref& __rhs) noexcept
  {
    return __lhs.__stream == __rhs.__stream;
  }

  /**
   * \brief Compares two `stream_ref`s for inequality
   *
   * \note Allows comparison with `hipStream_t` due to implicit conversion to
   * `stream_ref`.
   *
   * \param lhs The first `stream_ref` to compare
   * \param rhs The second `stream_ref` to compare
   * \return true if unequal, false if equal
   */
  _CCCL_NODISCARD_FRIEND constexpr bool operator!=(const stream_ref& __lhs, const stream_ref& __rhs) noexcept
  {
    return __lhs.__stream != __rhs.__stream;
  }

  /// Returns the wrapped `hipStream_t` handle.
  _CCCL_NODISCARD constexpr value_type get() const noexcept
  {
    return __stream;
  }

  /**
   * \brief Synchronizes the wrapped stream.
   *
   * \throws hip::cuda_error if synchronization fails.
   *
   */
  void wait() const
  {
    _CCCL_TRY_CUDA_API(::hipStreamSynchronize, "Failed to synchronize stream.", get());
  }

  /**
   * \brief Queries if all operations on the wrapped stream have completed.
   *
   * \throws hip::cuda_error if the query fails.
   *
   * \return `true` if all operations have completed, or `false` if not.
   */
  _CCCL_NODISCARD bool ready() const
  {
    const auto __result = ::hipStreamQuery(get());
    if (__result == ::hipErrorNotReady)
    {
      return false;
    }
    switch (__result)
    {
      case ::hipSuccess:
        break;
      default:
        (void)::hipGetLastError(); // Clear CUDA error state
        ::hip::__throw_cuda_error(__result, "Failed to query stream.");
    }
    return true;
  }

  /**
   * \brief Queries the priority of the wrapped stream.
   *
   * \throws hip::cuda_error if the query fails.
   *
   * \return value representing the priority of the wrapped stream.
   */
  _CCCL_NODISCARD int priority() const
  {
    int __result = 0;
    _CCCL_TRY_CUDA_API(::hipStreamGetPriority, "Failed to get stream priority", get(), &__result);
    return __result;
  }
};

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif //_CUDA_STREAM_REF
