/*
 *
 * Copyright 2018 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

#include <condition_variable>

#include <gtest/gtest.h>

#include <grpcpp/channel.h>

#include "src/proto/grpc/testing/echo.grpc.pb.h"
#include "test/cpp/util/string_ref_helper.h"

namespace grpc {
namespace testing {
/* This interceptor does nothing. Just keeps a global count on the number of
 * times it was invoked. */
class PhonyInterceptor : public experimental::Interceptor {
 public:
  PhonyInterceptor() {}

  void Intercept(experimental::InterceptorBatchMethods* methods) override {
    if (methods->QueryInterceptionHookPoint(
            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
      num_times_run_++;
    } else if (methods->QueryInterceptionHookPoint(
                   experimental::InterceptionHookPoints::
                       POST_RECV_INITIAL_METADATA)) {
      num_times_run_reverse_++;
    } else if (methods->QueryInterceptionHookPoint(
                   experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
      num_times_cancel_++;
    }
    methods->Proceed();
  }

  static void Reset() {
    num_times_run_.store(0);
    num_times_run_reverse_.store(0);
    num_times_cancel_.store(0);
  }

  static int GetNumTimesRun() {
    EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
    return num_times_run_.load();
  }

  static int GetNumTimesCancel() { return num_times_cancel_.load(); }

 private:
  static std::atomic<int> num_times_run_;
  static std::atomic<int> num_times_run_reverse_;
  static std::atomic<int> num_times_cancel_;
};

class PhonyInterceptorFactory
    : public experimental::ClientInterceptorFactoryInterface,
      public experimental::ServerInterceptorFactoryInterface {
 public:
  experimental::Interceptor* CreateClientInterceptor(
      experimental::ClientRpcInfo* /*info*/) override {
    return new PhonyInterceptor();
  }

  experimental::Interceptor* CreateServerInterceptor(
      experimental::ServerRpcInfo* /*info*/) override {
    return new PhonyInterceptor();
  }
};

/* This interceptor can be used to test the interception mechanism. */
class TestInterceptor : public experimental::Interceptor {
 public:
  TestInterceptor(const std::string& method, const char* suffix_for_stats,
                  experimental::ClientRpcInfo* info) {
    EXPECT_EQ(info->method(), method);

    if (suffix_for_stats == nullptr || info->suffix_for_stats() == nullptr) {
      EXPECT_EQ(info->suffix_for_stats(), suffix_for_stats);
    } else {
      EXPECT_EQ(strcmp(info->suffix_for_stats(), suffix_for_stats), 0);
    }
  }

  void Intercept(experimental::InterceptorBatchMethods* methods) override {
    methods->Proceed();
  }
};

class TestInterceptorFactory
    : public experimental::ClientInterceptorFactoryInterface {
 public:
  TestInterceptorFactory(const std::string& method,
                         const char* suffix_for_stats)
      : method_(method), suffix_for_stats_(suffix_for_stats) {}

  experimental::Interceptor* CreateClientInterceptor(
      experimental::ClientRpcInfo* info) override {
    return new TestInterceptor(method_, suffix_for_stats_, info);
  }

 private:
  std::string method_;
  const char* suffix_for_stats_;
};

/* This interceptor factory returns nullptr on interceptor creation */
class NullInterceptorFactory
    : public experimental::ClientInterceptorFactoryInterface,
      public experimental::ServerInterceptorFactoryInterface {
 public:
  experimental::Interceptor* CreateClientInterceptor(
      experimental::ClientRpcInfo* /*info*/) override {
    return nullptr;
  }

  experimental::Interceptor* CreateServerInterceptor(
      experimental::ServerRpcInfo* /*info*/) override {
    return nullptr;
  }
};

class EchoTestServiceStreamingImpl : public EchoTestService::Service {
 public:
  ~EchoTestServiceStreamingImpl() override {}

  Status Echo(ServerContext* context, const EchoRequest* request,
              EchoResponse* response) override {
    auto client_metadata = context->client_metadata();
    for (const auto& pair : client_metadata) {
      context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
    }
    response->set_message(request->message());
    return Status::OK;
  }

  Status BidiStream(
      ServerContext* context,
      grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
    EchoRequest req;
    EchoResponse resp;
    auto client_metadata = context->client_metadata();
    for (const auto& pair : client_metadata) {
      context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
    }

    while (stream->Read(&req)) {
      resp.set_message(req.message());
      EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
    }
    return Status::OK;
  }

  Status RequestStream(ServerContext* context,
                       ServerReader<EchoRequest>* reader,
                       EchoResponse* resp) override {
    auto client_metadata = context->client_metadata();
    for (const auto& pair : client_metadata) {
      context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
    }

    EchoRequest req;
    string response_str = "";
    while (reader->Read(&req)) {
      response_str += req.message();
    }
    resp->set_message(response_str);
    return Status::OK;
  }

  Status ResponseStream(ServerContext* context, const EchoRequest* req,
                        ServerWriter<EchoResponse>* writer) override {
    auto client_metadata = context->client_metadata();
    for (const auto& pair : client_metadata) {
      context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
    }

    EchoResponse resp;
    resp.set_message(req->message());
    for (int i = 0; i < 10; i++) {
      EXPECT_TRUE(writer->Write(resp));
    }
    return Status::OK;
  }
};

constexpr int kNumStreamingMessages = 10;

void MakeCall(const std::shared_ptr<Channel>& channel,
              const StubOptions& options = StubOptions());

void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);

void MakeCallbackCall(const std::shared_ptr<Channel>& channel);

bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
                   const string& key, const string& value);

bool CheckMetadata(const std::multimap<std::string, std::string>& map,
                   const string& key, const string& value);

std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreatePhonyClientInterceptors();

inline void* tag(int i) { return reinterpret_cast<void*>(i); }
inline int detag(void* p) {
  return static_cast<int>(reinterpret_cast<intptr_t>(p));
}

class Verifier {
 public:
  Verifier() : lambda_run_(false) {}
  // Expect sets the expected ok value for a specific tag
  Verifier& Expect(int i, bool expect_ok) {
    return ExpectUnless(i, expect_ok, false);
  }
  // ExpectUnless sets the expected ok value for a specific tag
  // unless the tag was already marked seen (as a result of ExpectMaybe)
  Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
    if (!seen) {
      expectations_[tag(i)] = expect_ok;
    }
    return *this;
  }
  // ExpectMaybe sets the expected ok value for a specific tag, but does not
  // require it to appear
  // If it does, sets *seen to true
  Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
    if (!*seen) {
      maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
    }
    return *this;
  }

  // Next waits for 1 async tag to complete, checks its
  // expectations, and returns the tag
  int Next(CompletionQueue* cq, bool ignore_ok) {
    bool ok;
    void* got_tag;
    EXPECT_TRUE(cq->Next(&got_tag, &ok));
    GotTag(got_tag, ok, ignore_ok);
    return detag(got_tag);
  }

  template <typename T>
  CompletionQueue::NextStatus DoOnceThenAsyncNext(
      CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
      std::function<void(void)> lambda) {
    if (lambda_run_) {
      return cq->AsyncNext(got_tag, ok, deadline);
    } else {
      lambda_run_ = true;
      return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
    }
  }

  // Verify keeps calling Next until all currently set
  // expected tags are complete
  void Verify(CompletionQueue* cq) { Verify(cq, false); }

  // This version of Verify allows optionally ignoring the
  // outcome of the expectation
  void Verify(CompletionQueue* cq, bool ignore_ok) {
    GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
    while (!expectations_.empty()) {
      Next(cq, ignore_ok);
    }
  }

  // This version of Verify stops after a certain deadline, and uses the
  // DoThenAsyncNext API
  // to call the lambda
  void Verify(CompletionQueue* cq,
              std::chrono::system_clock::time_point deadline,
              const std::function<void(void)>& lambda) {
    if (expectations_.empty()) {
      bool ok;
      void* got_tag;
      EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
                CompletionQueue::TIMEOUT);
    } else {
      while (!expectations_.empty()) {
        bool ok;
        void* got_tag;
        EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
                  CompletionQueue::GOT_EVENT);
        GotTag(got_tag, ok, false);
      }
    }
  }

 private:
  void GotTag(void* got_tag, bool ok, bool ignore_ok) {
    auto it = expectations_.find(got_tag);
    if (it != expectations_.end()) {
      if (!ignore_ok) {
        EXPECT_EQ(it->second, ok);
      }
      expectations_.erase(it);
    } else {
      auto it2 = maybe_expectations_.find(got_tag);
      if (it2 != maybe_expectations_.end()) {
        if (it2->second.seen != nullptr) {
          EXPECT_FALSE(*it2->second.seen);
          *it2->second.seen = true;
        }
        if (!ignore_ok) {
          EXPECT_EQ(it2->second.ok, ok);
        }
      } else {
        gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
        abort();
      }
    }
  }

  struct MaybeExpect {
    bool ok;
    bool* seen;
  };

  std::map<void*, bool> expectations_;
  std::map<void*, MaybeExpect> maybe_expectations_;
  bool lambda_run_;
};

}  // namespace testing
}  // namespace grpc