//
// Copyright 2020 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#ifndef GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H
#define GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H

#include <grpc/support/port_platform.h>

#include <grpc/support/atm.h>
#include <grpc/support/log.h>
#include <grpc/support/sync.h>

#include <atomic>
#include <cassert>
#include <cinttypes>

#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/atomic.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/orphanable.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"

namespace grpc_core {

// DualRefCounted is an interface for reference-counted objects with two
// classes of refs: strong refs (usually just called "refs") and weak refs.
// This supports cases where an object needs to start shutting down when
// all external callers are done with it (represented by strong refs) but
// cannot be destroyed until all internal callbacks are complete
// (represented by weak refs).
//
// Each class of refs can be incremented and decremented independently.
// Objects start with 1 strong ref and 0 weak refs at instantiation.
// When the strong refcount reaches 0, the object's Orphan() method is called.
// The object is destroyed once both the strong and the weak refcounts have
// reached 0.
//
// This is used via CRTP (curiously-recurring template pattern), e.g.:
//   class MyClass : public DualRefCounted<MyClass> { ... };
template <typename Child>
class DualRefCounted : public Orphanable {
 public:
  virtual ~DualRefCounted() = default;

  RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT {
    IncrementRefCount();
    return RefCountedPtr<Child>(static_cast<Child*>(this));
  }

  RefCountedPtr<Child> Ref(const DebugLocation& location,
                           const char* reason) GRPC_MUST_USE_RESULT {
    IncrementRefCount(location, reason);
    return RefCountedPtr<Child>(static_cast<Child*>(this));
  }

  void Unref() {
    // Convert strong ref to weak ref.
    const uint64_t prev_ref_pair =
        refs_.FetchAdd(MakeRefPair(-1, 1), MemoryOrder::ACQ_REL);
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
#ifndef NDEBUG
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
    if (trace_flag_ != nullptr && trace_flag_->enabled()) {
      gpr_log(GPR_INFO, "%s:%p unref %d -> %d, weak_ref %d -> %d",
              trace_flag_->name(), this, strong_refs, strong_refs - 1,
              weak_refs, weak_refs + 1);
    }
    GPR_ASSERT(strong_refs > 0);
#endif
    if (GPR_UNLIKELY(strong_refs == 1)) {
      Orphan();
    }
    // Now drop the weak ref.
    WeakUnref();
  }
  void Unref(const DebugLocation& location, const char* reason) {
    const uint64_t prev_ref_pair =
        refs_.FetchAdd(MakeRefPair(-1, 1), MemoryOrder::ACQ_REL);
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
#ifndef NDEBUG
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
    if (trace_flag_ != nullptr && trace_flag_->enabled()) {
      gpr_log(GPR_INFO, "%s:%p %s:%d unref %d -> %d, weak_ref %d -> %d) %s",
              trace_flag_->name(), this, location.file(), location.line(),
              strong_refs, strong_refs - 1, weak_refs, weak_refs + 1, reason);
    }
    GPR_ASSERT(strong_refs > 0);
#else
    // Avoid unused-parameter warnings for debug-only parameters
    (void)location;
    (void)reason;
#endif
    if (GPR_UNLIKELY(strong_refs == 1)) {
      Orphan();
    }
    // Now drop the weak ref.
    WeakUnref(location, reason);
  }

  RefCountedPtr<Child> RefIfNonZero() GRPC_MUST_USE_RESULT {
    uint64_t prev_ref_pair = refs_.Load(MemoryOrder::ACQUIRE);
    do {
      const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
#ifndef NDEBUG
      const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
      if (trace_flag_ != nullptr && trace_flag_->enabled()) {
        gpr_log(GPR_INFO, "%s:%p ref_if_non_zero %d -> %d (weak_refs=%d)",
                trace_flag_->name(), this, strong_refs, strong_refs + 1,
                weak_refs);
      }
#endif
      if (strong_refs == 0) return nullptr;
    } while (!refs_.CompareExchangeWeak(
        &prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), MemoryOrder::ACQ_REL,
        MemoryOrder::ACQUIRE));
    return RefCountedPtr<Child>(static_cast<Child*>(this));
  }

  RefCountedPtr<Child> RefIfNonZero(const DebugLocation& location,
                                    const char* reason) GRPC_MUST_USE_RESULT {
    uint64_t prev_ref_pair = refs_.Load(MemoryOrder::ACQUIRE);
    do {
      const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
#ifndef NDEBUG
      const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
      if (trace_flag_ != nullptr && trace_flag_->enabled()) {
        gpr_log(GPR_INFO,
                "%s:%p %s:%d ref_if_non_zero %d -> %d (weak_refs=%d) %s",
                trace_flag_->name(), this, location.file(), location.line(),
                strong_refs, strong_refs + 1, weak_refs, reason);
      }
#else
      // Avoid unused-parameter warnings for debug-only parameters
      (void)location;
      (void)reason;
#endif
      if (strong_refs == 0) return nullptr;
    } while (!refs_.CompareExchangeWeak(
        &prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), MemoryOrder::ACQ_REL,
        MemoryOrder::ACQUIRE));
    return RefCountedPtr<Child>(static_cast<Child*>(this));
  }

  WeakRefCountedPtr<Child> WeakRef() GRPC_MUST_USE_RESULT {
    IncrementWeakRefCount();
    return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
  }

  WeakRefCountedPtr<Child> WeakRef(const DebugLocation& location,
                                   const char* reason) GRPC_MUST_USE_RESULT {
    IncrementWeakRefCount(location, reason);
    return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
  }

  void WeakUnref() {
#ifndef NDEBUG
    // Grab a copy of the trace flag before the atomic change, since we
    // can't safely access it afterwards if we're going to be freed.
    auto* trace_flag = trace_flag_;
#endif
    const uint64_t prev_ref_pair =
        refs_.FetchSub(MakeRefPair(0, 1), MemoryOrder::ACQ_REL);
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
#ifndef NDEBUG
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
    if (trace_flag != nullptr && trace_flag->enabled()) {
      gpr_log(GPR_INFO, "%s:%p weak_unref %d -> %d (refs=%d)",
              trace_flag->name(), this, weak_refs, weak_refs - 1, strong_refs);
    }
    GPR_ASSERT(weak_refs > 0);
#endif
    if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
      delete static_cast<Child*>(this);
    }
  }
  void WeakUnref(const DebugLocation& location, const char* reason) {
#ifndef NDEBUG
    // Grab a copy of the trace flag before the atomic change, since we
    // can't safely access it afterwards if we're going to be freed.
    auto* trace_flag = trace_flag_;
#endif
    const uint64_t prev_ref_pair =
        refs_.FetchSub(MakeRefPair(0, 1), MemoryOrder::ACQ_REL);
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
#ifndef NDEBUG
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
    if (trace_flag != nullptr && trace_flag->enabled()) {
      gpr_log(GPR_INFO, "%s:%p %s:%d weak_unref %d -> %d (refs=%d) %s",
              trace_flag->name(), this, location.file(), location.line(),
              weak_refs, weak_refs - 1, strong_refs, reason);
    }
    GPR_ASSERT(weak_refs > 0);
#else
    // Avoid unused-parameter warnings for debug-only parameters
    (void)location;
    (void)reason;
#endif
    if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
      delete static_cast<Child*>(this);
    }
  }

  // Not copyable nor movable.
  DualRefCounted(const DualRefCounted&) = delete;
  DualRefCounted& operator=(const DualRefCounted&) = delete;

 protected:
  // TraceFlagT is defined to accept both DebugOnlyTraceFlag and TraceFlag.
  // Note: RefCount tracing is only enabled on debug builds, even when a
  //       TraceFlag is used.
  template <typename TraceFlagT = TraceFlag>
  explicit DualRefCounted(
      TraceFlagT*
#ifndef NDEBUG
          // Leave unnamed if NDEBUG to avoid unused parameter warning
          trace_flag
#endif
      = nullptr,
      int32_t initial_refcount = 1)
      :
#ifndef NDEBUG
        trace_flag_(trace_flag),
#endif
        refs_(MakeRefPair(initial_refcount, 0)) {
  }

 private:
  // Allow RefCountedPtr<> to access IncrementRefCount().
  template <typename T>
  friend class RefCountedPtr;
  // Allow WeakRefCountedPtr<> to access IncrementWeakRefCount().
  template <typename T>
  friend class WeakRefCountedPtr;

  // First 32 bits are strong refs, next 32 bits are weak refs.
  static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) {
    return (static_cast<uint64_t>(strong) << 32) + static_cast<int64_t>(weak);
  }
  static uint32_t GetStrongRefs(uint64_t ref_pair) {
    return static_cast<uint32_t>(ref_pair >> 32);
  }
  static uint32_t GetWeakRefs(uint64_t ref_pair) {
    return static_cast<uint32_t>(ref_pair & 0xffffffffu);
  }

  void IncrementRefCount() {
#ifndef NDEBUG
    const uint64_t prev_ref_pair =
        refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
    GPR_ASSERT(strong_refs != 0);
    if (trace_flag_ != nullptr && trace_flag_->enabled()) {
      gpr_log(GPR_INFO, "%s:%p ref %d -> %d; (weak_refs=%d)",
              trace_flag_->name(), this, strong_refs, strong_refs + 1,
              weak_refs);
    }
#else
    refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
#endif
  }
  void IncrementRefCount(const DebugLocation& location, const char* reason) {
#ifndef NDEBUG
    const uint64_t prev_ref_pair =
        refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
    GPR_ASSERT(strong_refs != 0);
    if (trace_flag_ != nullptr && trace_flag_->enabled()) {
      gpr_log(GPR_INFO, "%s:%p %s:%d ref %d -> %d (weak_refs=%d) %s",
              trace_flag_->name(), this, location.file(), location.line(),
              strong_refs, strong_refs + 1, weak_refs, reason);
    }
#else
    // Use conditionally-important parameters
    (void)location;
    (void)reason;
    refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
#endif
  }

  void IncrementWeakRefCount() {
#ifndef NDEBUG
    const uint64_t prev_ref_pair =
        refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
    if (trace_flag_ != nullptr && trace_flag_->enabled()) {
      gpr_log(GPR_INFO, "%s:%p weak_ref %d -> %d; (refs=%d)",
              trace_flag_->name(), this, weak_refs, weak_refs + 1, strong_refs);
    }
#else
    refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
#endif
  }
  void IncrementWeakRefCount(const DebugLocation& location,
                             const char* reason) {
#ifndef NDEBUG
    const uint64_t prev_ref_pair =
        refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
    const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
    const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
    if (trace_flag_ != nullptr && trace_flag_->enabled()) {
      gpr_log(GPR_INFO, "%s:%p %s:%d weak_ref %d -> %d (refs=%d) %s",
              trace_flag_->name(), this, location.file(), location.line(),
              weak_refs, weak_refs + 1, strong_refs, reason);
    }
#else
    // Use conditionally-important parameters
    (void)location;
    (void)reason;
    refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
#endif
  }

#ifndef NDEBUG
  TraceFlag* trace_flag_;
#endif
  Atomic<uint64_t> refs_;
};

}  // namespace grpc_core

#endif /* GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H */
