#include <linux/pci.h>

int pcidriver_pcie_get_mps(struct pci_dev *dev)
{
    u16 ctl;

    pcie_capability_read_word(dev, PCI_EXP_DEVCTL, &ctl);

    return 128 << ((ctl & PCI_EXP_DEVCTL_PAYLOAD) >> 5);
}

int pcidriver_pcie_set_mps(struct pci_dev *dev, int mps)
{
    u16 v;

    if (mps < 128 || mps > 4096 || !is_power_of_2(mps))
        return -EINVAL;

    v = ffs(mps) - 8;
    if (v > dev->pcie_mpss)
        return -EINVAL;
    v <<= 5;

    return pcie_capability_clear_and_set_word(dev, PCI_EXP_DEVCTL, PCI_EXP_DEVCTL_PAYLOAD, v);
}