/*
 * Copyright (C) 2022-2024 Kernkonzept GmbH.
 * Author(s): Jan Klötzke <jan.kloetzke@kernkonzept.com>
 *
 * License: see LICENSE.spdx (in this directory or the directories above)
 */

#include <cstring>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>

#include <l4/util/util.h>
#include <l4/sys/cxx/ipc_epiface>
#include <l4/re/error_helper>
#include <l4/l4virtio/client/virtio-block>

#include "debug.h"
#include "device_factory.h"
#include "guest.h"
#include "mmio_device.h"

namespace {

/**
 * Abstract storage backend of the CFI flash emulation.
 *
 * A backend provides a locally addressable buffer holding the flash contents.
 * The flash emulation modifies this buffer directly and reports the modified
 * ranges via taint(); write_back() gives the backend the chance to persist
 * them.
 */
class Cfi_backend
{
public:
  virtual ~Cfi_backend() = default;

  /// Name of the backend; used as the device name of the flash emulation.
  virtual char const *dev_name() = 0;
  /// Local address of the buffer holding the flash contents.
  virtual char *local_addr() const = 0;
  /// Size in bytes reported by the backend for its buffer.
  virtual l4_size_t mapped_size() const = 0;
  /**
   * Report that the buffer was modified.
   *
   * \param off  Byte offset of the modification relative to local_addr().
   * \param len  Number of modified bytes.
   */
  virtual void taint(l4_addr_t off, l4_size_t len = 1) = 0;
  /// Persist the modifications reported via taint() (backend specific).
  virtual void write_back() = 0;
};

/**
 * CFI backend that operates directly on a dataspace.
 *
 * The dataspace is made locally accessible through a Vmm::Ds_manager. The
 * flash emulation works on that memory directly, hence neither tainting nor
 * writing back requires any action here.
 */
class Cfi_backend_dataspace : public Cfi_backend
{
public:
  /**
   * Create the backend for the first `size` bytes of `ds`.
   *
   * \param ds    Dataspace holding the flash contents.
   * \param size  Number of bytes handed to the dataspace manager.
   * \param ro    Request read-only (true) or read-write (false) region flags.
   */
  Cfi_backend_dataspace(L4::Cap<L4Re::Dataspace> ds, l4_uint64_t size, bool ro)
  : _mgr(cxx::make_unique<Vmm::Ds_manager>(
           dev_name(), ds, 0, size,
           ro ? L4Re::Rm::F::R : L4Re::Rm::F::RW))
  {}

  /// Fixed name of this backend.
  char const *dev_name() override
  {
    return "Cfi_flash-ds";
  }

  /// Local address as provided by the dataspace manager.
  char *local_addr() const override
  {
    return _mgr->local_addr<char *>();
  }

  /// Size as reported by the dataspace manager.
  l4_size_t mapped_size() const override
  {
    return _mgr->size();
  }

  /// No dirty tracking: modifications are done on the dataspace memory.
  void taint(l4_addr_t, l4_size_t) override
  {}

  /// Nothing to flush for this backend.
  void write_back() override
  {}

private:
  cxx::unique_ptr<Vmm::Ds_manager> _mgr;
};

/**
 * CFI backend that stores the flash contents on a virtio block device.
 *
 * The whole flash image is read into a local buffer when the backend is
 * created. Modifications are tracked as a single contiguous range of dirty
 * sectors which is written to the device by write_back().
 */
class Cfi_backend_virtio_block final : public Cfi_backend
{
  enum { Sector_size = 512 };

public:
  /**
   * Set up the virtio block device and read the initial flash contents.
   *
   * \param cap   Capability of the virtio block device.
   * \param size  Size of the flash in bytes; rounded up to a full page.
   *
   * Throws (via L4Re::throw_error / L4Re::chksys) if the device is smaller
   * than the flash, if the read-only feature was negotiated, or if the
   * initial read request fails.
   */
  Cfi_backend_virtio_block(L4::Cap<L4virtio::Device> cap, l4_uint64_t size)
  : _size(l4_round_page(size))
  {
    _dev.setup_device(cap, _size, &_localaddr, _devaddr);

    l4_uint64_t sectors = _dev.device_config().capacity;
    if (_size > sectors * Sector_size)
      L4Re::throw_error(-L4_EINVAL,
                        "Block device size too small for CFI registers.");

    if (_dev.feature_negotiated(5 /* VIRTIO_BLK_F_RO */))
      L4Re::throw_error(-L4_EINVAL,
                        "CFI: virtio device read only. Not supported.");

    // read device up-front
    auto h = _dev.start_request(0, L4VIRTIO_BLOCK_T_IN, 0);
    L4Re::chksys(_dev.add_block(h, _devaddr, _size),
                 "CFI: Error during virtio setup: add block failed.");
    L4Re::chksys(_dev.process_request(h),
                 "CFI: Error during virtio setup: process request failed.");

    // bring in pages
    l4_touch_ro(_localaddr, _size);
  }

  /// Fixed name of this backend.
  char const *dev_name() override
  { return "Cfi_flash-virtio-blk"; }

  /// Local address of the buffer shared with the block device.
  char *local_addr() const override
  { return reinterpret_cast<char *>(_localaddr); }

  /// Page-rounded size of the buffer.
  l4_size_t mapped_size() const override
  { return _size; }

  /**
   * Mark the sectors covering [off, off + len) as dirty.
   *
   * Only one contiguous dirty sector range is tracked. If the new range
   * neither overlaps nor adjoins the current one, the current range is
   * written back first.
   *
   * \param off  Byte offset into the buffer.
   * \param len  Number of modified bytes; zero means nothing was modified.
   */
  void taint(l4_addr_t off, l4_size_t len) override
  {
    // An empty range dirties nothing. Returning early also prevents the
    // unsigned wrap-around of 'off + len - 1U' below, which would otherwise
    // mark a bogus, huge sector range as dirty.
    if (len == 0)
      return;

    l4_addr_t start = off / Sector_size;
    l4_addr_t end = (off + len - 1U) / Sector_size;

    // Disjoint from the tracked range: flush the old range first. In the
    // initial (empty) state write_back() is a no-op.
    if (end + 1U < _dirty_start || _dirty_end + 1U < start)
      write_back();

    if (start < _dirty_start)
      _dirty_start = start;
    if (end > _dirty_end)
      _dirty_end = end;
  }

  /**
   * Write the dirty sector range to the block device and reset the tracking.
   *
   * Errors are only reported as warnings; the dirty range is reset in any
   * case.
   */
  void write_back() override
  {
    if (_dirty_start <= _dirty_end)
      {
        l4_size_t blocks = _dirty_end - _dirty_start + 1U;
        auto da = L4virtio::Ptr<void>(_devaddr.get() + _dirty_start * Sector_size);
        auto h = _dev.start_request(_dirty_start, L4VIRTIO_BLOCK_T_OUT, 0);

        // There is no way to recover from errors here.
        // At least tell the user something went wrong.
        if (_dev.add_block(h, da, blocks * Sector_size) < 0)
          warn().printf("write_back: add block failed\n");
        else if (_dev.process_request(h) < 0)
          warn().printf("write_back: process request failed\n");

        // Empty range: start > end.
        _dirty_start = ~0UL;
        _dirty_end = 0;
      }
  }

private:
  static Dbg warn() { return Dbg(Dbg::Dev, Dbg::Warn, "CFI(vio)"); }

  l4_size_t _size;
  void *_localaddr = nullptr;
  L4virtio::Ptr<void> _devaddr;
  L4virtio::Driver::Block_device _dev;

  // Inclusive range of dirty sectors; empty while _dirty_start > _dirty_end.
  l4_addr_t _dirty_start = ~0UL;
  l4_addr_t _dirty_end = 0;
};

/**
 * Simple CFI compliant flash with Intel command set.
 *
 * Example device tree:
 *
 * \code{.dtb}
 *   flash@ffc00000 {
 *       compatible = "cfi-flash";
 *       reg = <0x0 0xffc00000 0x0 0x84000>;
 *       l4vmm,dscap = "capname";
 *       erase-size = <0x10000>; // must be power of two
 *       bank-width = <4>;
 *       device-width = <2>; // optional, equal to bank-width by default
 *   };
 * \endcode
 *
 * 'bank-width' configures the total bus width of the flash (in bytes).
 * It is typically equal to the 'device-width', unless multiple flash chips
 * share the bus. In this case 'device-width' refers to the width of a single
 * chip. The example above configures a 32-bit wide flash that consists of
 * two 16-bit chips.
 *
 * The optional "read-only" property will make the device read-only. If the
 * dataspace 'capname' is read-only but the 'read-only' property is not set,
 * this emulation will make the flash device read-only as well (but give a
 * warning about it).
 *
 * How to make the dscap writable:
 * 1. Load the bootmodule (modules.list):
 *    module OVMF_VARS.fd :rw
 * 2. Add the bootmodule to the caps table of uvmm:
 *    uvmm_caps = {
 *      capname = L4.Env.rwfs:query("OVMF_VARS.fd", 7),
 *    }
 *    L4.Env.loader:startv(caps=uvmm_caps, "rom/uvmm")
 *
 * Optionally, this CFI emulation also supports virtio-block as backend.
 * By replacing 'l4vmm,dscap = "capname"' with 'l4vmm,virtiocap = "capname"'
 * you may set 'capname' as a cap pointing to a virtio-block server.
 *
 * Notes about the CFI emulation.
 * - CFI operates in either read or write-mode
 * - in read-mode the flash acts like RAM and Linux will use the full
 *   instruction set to read it(*)
 * - in write mode Linux handles the device like an MMIO device
 * - read-mode is emulated by mapping the DS in read-only fashion to the guest
 * - if switched to write mode, we unmap the DS and emulate individual accesses
 * (*) Full emulation of read-mode would not be very performant and
 *     require a full instruction decoder, which --for x86-- we do not have
 *     and do not want.
 * - multiple chips emulated on the same bus (bank-width != device-width)
 *   are not independent: they must always receive the same commands
 */
class Cfi_flash
: public Vmm::Mmio_device,
  public Vdev::Device
{
  enum
  {
    // Commands written by the guest. The current command is kept in _cmd and
    // selects how subsequent accesses are interpreted.
    Cmd_write_byte = 0x10,
    Cmd_block_erase = 0x20,
    Cmd_write_byte2 = 0x40,
    Cmd_clear_status = 0x50,
    Cmd_read_status = 0x70,
    Cmd_read_device_id = 0x90,
    Cmd_cfi_query = 0x98,
    Cmd_program_erase_suspend = 0xb0,
    Cmd_block_confirm = 0xd0,
    Cmd_write_block = 0xe8,
    Cmd_read_array = 0xff,

    // Bits of the status register (_status).
    Status_ready = 1 << 7,
    Status_erase_error = 1 << 5,
    Status_program_error = 1 << 4,

    Cfi_table_size = 0x40,
    Block_buffer_shift = 10, // 1 KiB
    Block_buffer_size = 1 << Block_buffer_shift,
  };

public:
  /**
   * Create the flash emulation and fill the CFI query table.
   *
   * \param be            Storage backend holding the flash contents.
   * \param base          Guest address of the flash.
   * \param size          Size of the flash in bytes.
   * \param erase_size    Erase block size in bytes (power of two).
   * \param ro            Flash is read-only; program/erase set error bits.
   * \param bank_width    Total bus width in bytes (non-zero power of two).
   * \param device_width  Width of a single chip in bytes (non-zero power of
   *                      two, not larger than bank_width).
   *
   * The constructor divides by `device_width` and takes the logarithm of
   * `bank_width / device_width`; callers must pass validated values.
   */
  Cfi_flash(cxx::unique_ptr<Cfi_backend> be, l4_addr_t base, size_t size,
            size_t erase_size, bool ro, unsigned int bank_width,
            unsigned int device_width)
  : _be(cxx::move(be)), _base(base), _size(size), _erase_size(erase_size),
    _ro(ro), _bank_width(bank_width), _device_width(device_width)
  {
    // log2 of the number of chips sharing the bus
    unsigned int chip_shift = 8 * sizeof(unsigned int)
                              - __builtin_clz(bank_width / device_width) - 1;

    // Fill CFI table. See JESD68.01...
    _cfi_table[0x10] = 'Q';
    _cfi_table[0x11] = 'R';
    _cfi_table[0x12] = 'Y';
    _cfi_table[0x13] = 0x01; // Intel command set
    _cfi_table[0x15] = 0x31; // Address of "PRI" below
    // Typical/maximum timeout for buffer write in 2^n
    // This must be set because all zero means "not supported"
    _cfi_table[0x20] = 1; // 2us
    _cfi_table[0x24] = 1; // 4us (2^1 multiplied by typical time above)
    // Device size in 2^n (flash size rounded up to a power of two, divided
    // by number of chips)
    _cfi_table[0x27] = 8 * sizeof(unsigned long) - __builtin_clzl(_size - 1U)
                       - chip_shift;
    // Block buffer size in 2^n (divided by number of chips)
    auto block_buf_shift = Block_buffer_shift - chip_shift;
    _cfi_table[0x2a] = cxx::min(block_buf_shift, device_width * 8);
    _cfi_table[0x2c] = 1; // one erase block region

    // Erase block region 1 (our only one)
    size_t num_blocks = (_size + erase_size - 1U) / erase_size;
    _cfi_table[0x2d] = num_blocks - 1U;
    _cfi_table[0x2e] = (num_blocks - 1U) >> 8;
    // Divide erase size by number of chips
    erase_size >>= chip_shift;
    _cfi_table[0x2f] = erase_size >> 8;
    _cfi_table[0x30] = erase_size >> 16;

    // Intel Primary Algorithm Extended Query Table
    _cfi_table[0x31] = 'P';
    _cfi_table[0x32] = 'R';
    _cfi_table[0x33] = 'I';
    _cfi_table[0x34] = '1';
    _cfi_table[0x35] = '0';

    info().printf("CFI flash (size %zu, %s, bank width: %u, device width: %u, erase size = %zu)\n",
                  _size, _ro ? "ro" : "rw", _bank_width, _device_width, _erase_size);
  }

  ~Cfi_flash()
  {}

  /**
   * Handle a guest access that faulted on the flash region.
   *
   * Stores are always emulated as commands/data. A non-store access while in
   * read-array mode maps the flash memory read-only into the guest and asks
   * for the instruction to be retried. Loads in any other mode are emulated.
   *
   * \param pfa      Faulting address; only used for the warning message.
   * \param offset   Offset of the access relative to the flash base.
   * \param vcpu     vCPU used to decode the access and write back load values.
   * \param vm_task  Guest task used for mapping/unmapping the flash memory.
   *
   * \return Vmm::Jump_instr for emulated accesses, Vmm::Retry after mapping
   *         the memory, -L4_ENXIO if the access could not be classified.
   */
  int access(l4_addr_t pfa, l4_addr_t offset, Vmm::Vcpu_ptr vcpu,
             L4::Cap<L4::Vm> vm_task, l4_addr_t, l4_addr_t) override
  {
    auto insn = vcpu.decode_mmio();

    if (insn.access == Vmm::Mem_access::Store)
      {
        write(vm_task, offset, insn.width, insn.value);
        return Vmm::Jump_instr;
      }
    else if (_cmd == Cmd_read_array)
      {
        map_mem_ro(vm_task);
        return Vmm::Retry;
      }
    else if (insn.access == Vmm::Mem_access::Load)
      {
        insn.value = read(offset, insn.width);
        vcpu.writeback_mmio(insn);
        return Vmm::Jump_instr;
      }
    else
      {
        warn().printf("MMIO access @ 0x%lx: unknown instruction. Ignored.\n",
                      pfa);
        return -L4_ENXIO;
      }
  }

  /// Nothing is mapped eagerly; mappings are established by set_mode()/access().
  void map_eager(L4::Cap<L4::Vm>, Vmm::Guest_addr, Vmm::Guest_addr) override
  {}

  /// Device name as reported by the backend.
  char const *dev_name() const override { return _be->dev_name(); }

private:
  /**
   * Switch to command mode `cmd`.
   *
   * Leaving read-array mode removes the guest mapping such that accesses are
   * emulated. Entering read-array mode writes back pending modifications and
   * maps the flash memory read-only.
   */
  void set_mode(L4::Cap<L4::Vm> vm_task, uint8_t cmd)
  {
    if (_cmd == Cmd_read_array && cmd != Cmd_read_array)
      unmap_mem(vm_task);

    _cmd = cmd;

    // Proactively map the flash memory, to avoid instruction decoding on reads.
    if (cmd == Cmd_read_array)
      {
        _be->write_back();
        map_mem_ro(vm_task);
      }
  }

  /// Map the backend buffer to the guest at _base with L4_FPAGE_RX rights.
  void map_mem_ro(L4::Cap<L4::Vm> vm_task)
  {
    auto local = reinterpret_cast<l4_addr_t>(local_addr());
    map_guest_range(vm_task, Vmm::Guest_addr(_base), local, _size, L4_FPAGE_RX);
  }

  /// Remove the guest mapping of the flash region.
  void unmap_mem(L4::Cap<L4::Vm> vm_task)
  {
    unmap_guest_range(vm_task, Vmm::Guest_addr(_base), _size);
  }

  char *local_addr() const
  { return _be->local_addr(); }

  l4_size_t mapped_size() const
  { return _be->mapped_size(); }

  /// Mask selecting the lowest _device_width bytes of a machine word.
  l4_umword_t device_mask() const
  { return ~0UL >> ((sizeof(l4_umword_t) - _device_width) * 8); }

  /**
   * Replicate a single-chip value for all chips on the bus.
   *
   * \param device_val  Value of a single chip.
   * \param size        Access width as passed to Vmm::Mem_access::read().
   */
  l4_umword_t chip_shift(l4_umword_t device_val, char size)
  {
    // Duplicate the device value shifted for the other chips on the same bus
    l4_umword_t val = 0;
    for (auto shift = 0U; shift < _bank_width; shift += _device_width)
      val |= device_val << (shift * 8);
    // Clear bits not visible for the access width
    return Vmm::Mem_access::read(val, 0, size);
  }

  /**
   * Check that `val` carries the same value for every chip on the bus.
   *
   * \return true if all chips receive the same value, false (after printing a
   *         warning) otherwise.
   */
  bool check_chip_shift(l4_umword_t val)
  {
    auto device_val = val & device_mask();
    for (auto shift = _device_width; shift < _bank_width; shift += _device_width)
      {
        if (device_val != ((val >> (shift * 8)) & device_mask()))
          {
            warn().printf("Invalid command: 0x%lx, must be the same for all "
                          "chips\n", val);
            return false;
          }
      }
    return true;
  }

  /**
   * Emulate a load from the flash.
   *
   * \param reg   Byte offset of the access relative to the flash base.
   * \param size  log2 of the access width in bytes.
   *
   * \return The value depending on the current mode: flash contents in
   *         read-array mode, the CFI query table in query mode, 0 for the
   *         (unimplemented) device ID, the status register otherwise.
   *         All ones if the access exceeds the flash size.
   */
  l4_umword_t read(unsigned reg, char size)
  {
    if (reg + (1U << size) > _size)
      return -1;

    switch (_cmd)
      {
      case Cmd_read_array:
        {
          auto addr = reinterpret_cast<l4_addr_t>(local_addr() + reg);
          return Vmm::Mem_access::read_width(addr, size);
        }
      case Cmd_read_device_id:
        // Currently not implemented. Add once needed.
        return 0;
      case Cmd_cfi_query:
        {
          if (reg % _bank_width)
            {
              warn().printf("Unaligned read of CFI query: 0x%x\n", reg);
              return 0;
            }
          reg /= _bank_width;

          // Calculate number of elements to be read from the CFI query.
          // Multiple elements are read when the access width is larger than
          // the bank width. Rounding up is necessary in case a smaller access
          // width is used (e.g. 8-bit reads on a 32-bit flash).
          auto nregs = ((1U << size) + _bank_width - 1) / _bank_width;
          if (reg + nregs > sizeof(_cfi_table))
            return 0;

          // Fill the value using the _cfi_table...
          l4_umword_t val = 0;
          for (auto i = 0U; i < nregs; i++)
            {
              auto shift = i * _bank_width * 8;
              // Elements that do not fit into a machine word cannot be
              // represented; shifting by the word width or more is undefined.
              if (shift >= sizeof(l4_umword_t) * 8)
                break;
              // Widen before shifting: the table entry would otherwise be
              // promoted to (signed) int, which overflows for large shifts
              // and sign-extends entries with bit 7 set.
              val |= static_cast<l4_umword_t>(_cfi_table[reg + i]) << shift;
            }
          // ... and duplicate it for all chips
          return chip_shift(val, size);
        }
      default:
        // read status
        return chip_shift(_status, size);
      }
  }

  /**
   * Emulate a store to the flash.
   *
   * Depending on the current mode the value is either data (byte/block
   * program), a confirmation (block erase) or a new command.
   *
   * \param vm_task  Guest task, needed for mode switches.
   * \param reg      Byte offset of the access relative to the flash base.
   * \param size     log2 of the access width in bytes.
   * \param value    Value written by the guest.
   *
   * Accesses exceeding the flash size are ignored.
   */
  void write(L4::Cap<L4::Vm> vm_task, unsigned reg, char size, l4_umword_t value)
  {
    if (reg + (1U << size) > _size)
      return;

    l4_uint8_t cmd = value;
    switch (_cmd)
      {
      case Cmd_write_byte:
      case Cmd_write_byte2:
        // Data cycle of a single program operation.
        if (_ro)
          _status |= Status_program_error;
        else
          {
            // Programming can only clear bits (bitwise AND with old contents)
            auto addr = reinterpret_cast<l4_addr_t>(local_addr() + reg);
            auto before = Vmm::Mem_access::read_width(addr, size);
            Vmm::Mem_access::write_width(addr, before & value, size);
            _be->taint(reg, 1U << size);
          }
        _status |= Status_ready;
        set_mode(vm_task, Cmd_read_status);
        break;
      case Cmd_block_erase:
        // Second cycle of a block erase: only the confirm command is valid.
        if (!check_chip_shift(value))
          {
            _status |= Status_program_error | Status_erase_error;
            return;
          }
        switch (cmd)
          {
            case Cmd_block_confirm:
              _status |= Status_ready;
              if (_ro)
                _status |= Status_erase_error;
              else
                {
                  // Erase the whole block containing the accessed address
                  reg &= ~(_erase_size - 1U);
                  memset(local_addr() + reg, 0xff, _erase_size);
                  _be->taint(reg, _erase_size);
                  _be->write_back();
                }
              break;
            default:
              info().printf("Invalid command after Cmd_block_erase: 0x%02x\n",
                            cmd);
              _status |= Status_program_error | Status_erase_error;
              return;
          }
        set_mode(vm_task, Cmd_read_status);
        break;

      case Cmd_write_block:
        // Count, data and confirm cycles are handled by write_block().
        if (!write_block(vm_task, reg, size, value))
          {
            _status |= Status_program_error;
            set_mode(vm_task, Cmd_read_status);
          }
        break;

      case Cmd_read_status:
      case Cmd_read_device_id:
      case Cmd_cfi_query:
      case Cmd_read_array:
      case Cmd_program_erase_suspend:
        // No operation in progress: the value is a new command.
        trace().printf("Command 0x%02x @ %u\n", cmd, reg);
        if (!check_chip_shift(value))
          return;
        switch (cmd)
          {
          case Cmd_clear_status:
            _status = 0;
            break;
          case Cmd_program_erase_suspend:
            _status |= Status_ready;
            [[fallthrough]];
          case Cmd_write_byte:
          case Cmd_write_byte2:
          case Cmd_block_erase:
          case Cmd_read_status:
          case Cmd_read_device_id:
          case Cmd_cfi_query:
          case Cmd_read_array:
            set_mode(vm_task, cmd);
            break;
          case Cmd_write_block:
            if (_ro)
              {
                _status |= Status_program_error;
                break;
              }
            // _buf_len == 0 marks that the word count is expected next
            _buf_len = 0;
            _status |= Status_ready;
            set_mode(vm_task, cmd);
            break;
          default:
            warn().printf("Unsupported command: %02x\n", cmd);
            break;
          }
        break;
      }
  }

  /**
   * Handle one store while in block write (buffered program) mode.
   *
   * The sequence consists of the word count, the data words and a final
   * confirm command. Data is collected in _buffer and only copied to the
   * flash memory when the confirmation arrives.
   *
   * \return true if the store was accepted, false if the sequence is invalid
   *         and must be aborted by the caller.
   */
  bool write_block(L4::Cap<L4::Vm> vm_task, unsigned reg, char size, l4_umword_t value)
  {
    if (!_buf_len)
      { // start of block write
        if (!check_chip_shift(value))
          return false;

        auto count = (value & device_mask()) + 1; // value = words - 1
        count *= _bank_width; // convert to bytes
        if (count > Block_buffer_size)
          {
            warn().printf("Invalid block write size: %lu val 0x%lx\n",
                          count, value);
            return false;
          }

        _buf_len = count;
        _buf_written = 0;
        return true;
      }

    if (!_buf_written)
      { // set start address on the first write
        trace().printf("Start block write at 0x%x with %u bytes\n",
                       reg, _buf_len);
        if (reg + _buf_len > _size)
          {
            warn().printf("Block write out of bounds: 0x%x + %u\n",
                          reg, _buf_len);
            return false;
          }
        // fill temporary buffer with original values
        // this is necessary because writes can only clear bits (bitwise AND)
        _buf_start = reg;
        memcpy(&_buffer, local_addr() + reg, _buf_len);
      }

    if (_buf_written >= _buf_len)
      { // all words written, write confirmed?
        if (!check_chip_shift(value))
          return false;

        trace().printf("Confirm buffer write with 0x%lx\n", value);

        if ((value & device_mask()) != Cmd_block_confirm)
          return false;

        // write back buffer
        memcpy(local_addr() + _buf_start, _buffer, _buf_len);
        _be->taint(_buf_start, _buf_len);
        _be->write_back();
        set_mode(vm_task, Cmd_read_status);
        return true;
      }

    if (_buf_start <= reg && (reg + (1 << size)) <= (_buf_start + _buf_len))
      { // write into buffer
        auto addr = reinterpret_cast<l4_addr_t>(&_buffer[reg - _buf_start]);
        auto before = Vmm::Mem_access::read_width(addr, size);
        Vmm::Mem_access::write_width(addr, before & value, size);
        _buf_written += 1 << size;
        return true;
      }

    // write out of bounds
    trace().printf("Out of bounds write to buffer; abort: 0x%x = 0x%lx\n",
                   reg, value);
    return false;
  }

  static Dbg info() { return Dbg(Dbg::Dev, Dbg::Info, "CFI"); }
  static Dbg warn() { return Dbg(Dbg::Dev, Dbg::Warn, "CFI"); }
  static Dbg trace() { return Dbg(Dbg::Dev, Dbg::Trace, "CFI"); }

  cxx::unique_ptr<Cfi_backend> _be;
  l4_addr_t _base;
  size_t _size, _erase_size;
  bool _ro;
  unsigned int _bank_width, _device_width;

  // Current command mode and status register
  l4_uint8_t _cmd = Cmd_read_array;
  l4_uint8_t _status = 0;

  l4_uint8_t _cfi_table[Cfi_table_size] = { 0 };

  // State of a block write: staging buffer, its flash offset, the expected
  // number of bytes and the number of bytes written so far.
  l4_uint8_t _buffer[Block_buffer_size];
  unsigned int _buf_start = 0;
  unsigned int _buf_len = 0;
  unsigned int _buf_written = 0;
};

/**
 * Factory creating a Cfi_flash device from a device tree node.
 */
struct F : Vdev::Factory
{
  /**
   * Validate the node properties, create the storage backend and register
   * the flash as MMIO device.
   *
   * Evaluated properties: 'reg', 'erase-size', 'read-only', 'l4vmm,dscap',
   * 'l4vmm,virtiocap', 'bank-width' and the optional 'device-width'.
   *
   * \param devs  Device lookup, used to register the MMIO device.
   * \param node  Device tree node describing the flash.
   *
   * \return The created device, or nullptr (after printing a warning) if the
   *         configuration is invalid or no backend could be created.
   */
  cxx::Ref_ptr<Vdev::Device> create(Vdev::Device_lookup *devs,
                                    Vdev::Dt_node const &node) override
  {
    auto warn = Dbg(Dbg::Dev, Dbg::Warn, "CFI");
    l4_uint64_t base, size;
    int res = node.get_reg_val(0, &base, &size);
    if (res < 0)
      {
        warn.printf("Missing 'reg' property for node %s\n", node.get_name());
        return nullptr;
      }

    auto erase_size = fdt32_to_cpu(*node.check_prop<fdt32_t>("erase-size", 1));
    // Zero passes the 'x & (x - 1)' power-of-two test but would cause a
    // division by zero below, so reject it explicitly.
    if (erase_size == 0 || (erase_size & (erase_size - 1)))
      {
        warn.printf("erase-size must be a power of two: %u\n", erase_size);
        return nullptr;
      }

    if (size < erase_size || size % erase_size)
      {
        warn.printf("Wrong device size! Must be a multiple of erase block size.\n");
        return nullptr;
      }

    bool ro = node.has_prop("read-only");

    cxx::unique_ptr<Cfi_backend> be;

    auto dscap = Vdev::get_cap<L4Re::Dataspace>(node, "l4vmm,dscap");
    auto viocap = Vdev::get_cap<L4virtio::Device>(node, "l4vmm,virtiocap");

    if (dscap && viocap)
      warn.printf("Both dscap and virtiocap defined. Choosing virtiocap.\n");

    if (viocap)
      {
        try
          {
            be.reset(new Cfi_backend_virtio_block(viocap, size));
          }
        catch (L4::Runtime_error const &e)
          {
            warn.printf("Error in CFI virtio backend constructor %s: '%s'. "
                        "Disabling device.\n", e.str(), e.extra_str());
          }
      }
    else if (dscap)
      {
        // A read-only dataspace forces read-only operation of the flash.
        if (!ro && !dscap->flags().w())
          {
            warn.printf(
              "DT configures flash to be writable, but dataspace is read-only. "
              "Defaulting to read-only operation.\n");
            ro = true;
          }

        if (size > dscap->size())
          {
            warn.printf(
              "Dataspace is too small for the CFI registers. "
              "This is not supported.\n");
          }
        else
          be.reset(new Cfi_backend_dataspace(dscap, size, ro));
      }

    if (!be)
      {
        warn.printf("Neither working 'l4vmm,dscap' nor 'l4vmm,virtiocap' "
                    "property!\n");
        return nullptr;
      }

    auto bank_width = fdt32_to_cpu(*node.check_prop<fdt32_t>("bank-width", 1));
    // Must be a non-zero power of two that fits into a machine word. Zero
    // would lead to divisions by zero in the flash emulation.
    if (bank_width == 0 || (bank_width & (bank_width - 1))
        || bank_width > sizeof(l4_umword_t))
      {
        warn.printf("Invalid bank-width value: %u\n", bank_width);
        return nullptr;
      }

    int prop_size;
    auto prop = node.get_prop<fdt32_t>("device-width", &prop_size);
    // 'device-width' is optional and defaults to the bank width.
    auto device_width = bank_width;
    if (prop)
      {
        if (prop_size != 1)
          {
            warn.printf("Invalid device-width property size: %d\n", prop_size);
            return nullptr;
          }
        device_width = fdt32_to_cpu(*prop);
      }
    // Must be a non-zero power of two not exceeding the bank width. The flash
    // emulation divides the bank width by this value.
    if (device_width == 0 || (device_width & (device_width - 1))
        || device_width > bank_width)
      {
        warn.printf("Invalid device-width value: %u\n", device_width);
        return nullptr;
      }

    auto c = Vdev::make_device<Cfi_flash>(cxx::move(be), base, size, erase_size,
                                          ro, bank_width, device_width);
    devs->vmm()->register_mmio_device(c, Vmm::Region_type::Virtual, node);

    return c;
  }
};

}

// Factory instance and the device type entry binding it to the "cfi-flash"
// compatible string.
// NOTE(review): presumably Vdev::Device_type registers itself with the device
// factory lookup on construction -- confirm in device_factory.h.
static F f;
static Vdev::Device_type t = { "cfi-flash", nullptr, &f };
