diff --git a/drivers/coda/coda.c b/drivers/coda/coda.c index 965d1ec4877e908ff0d07cfe8f4070a0bf8a503d..be2f4adca9c26e5943193305f05b30a1aaee449e 100644 --- a/drivers/coda/coda.c +++ b/drivers/coda/coda.c @@ -72,9 +72,10 @@ static int get_root_bd(struct device *dev) * @devs: All child devices under input dev * @max_devs: Max num of devs * @ndev: Num of child devices + * @pdevs: All child pci_dev under input dev */ static void get_child_devices_rec(struct pci_dev *dev, uint16_t *devs, - int max_devs, int *ndev) + int max_devs, int *ndev, struct pci_dev **pdevs) { struct pci_bus *bus = dev->subordinate; @@ -82,7 +83,7 @@ static void get_child_devices_rec(struct pci_dev *dev, uint16_t *devs, struct pci_dev *child; list_for_each_entry(child, &bus->devices, bus_list) { - get_child_devices_rec(child, devs, max_devs, ndev); + get_child_devices_rec(child, devs, max_devs, ndev, pdevs); } } else { /* dev is a regular device */ uint16_t bdf = pci_dev_id(dev); @@ -98,6 +99,7 @@ static void get_child_devices_rec(struct pci_dev *dev, uint16_t *devs, return; } devs[*ndev] = bdf; + pdevs[*ndev] = dev; *ndev = *ndev + 1; } } @@ -106,12 +108,14 @@ static void get_child_devices_rec(struct pci_dev *dev, uint16_t *devs, * get_sibling_devices - Get all devices which share the same root_bd as dev * @dev: Device for which to get child devices * @devs: All child devices under input dev + * @pdevs: All child pci_dev under input dev * @max_devs: Max num of devs * * Returns: * %0 if get child devices failure */ -static int get_sibling_devices(struct device *dev, uint16_t *devs, int max_devs) +static int get_sibling_devices(struct device *dev, uint16_t *devs, + struct pci_dev **pdevs, int max_devs) { struct pci_dev *pdev; int ndev = 0; @@ -130,6 +134,7 @@ static int get_sibling_devices(struct device *dev, uint16_t *devs, int max_devs) */ if (pdev->is_virtfn) { devs[ndev] = pci_dev_id(pdev); + pdevs[ndev] = pdev; ndev = ndev + 1; pdev = pci_physfn(pdev); } @@ -137,7 +142,7 @@ static int get_sibling_devices(struct device *dev, uint16_t *devs, int max_devs) while (!pci_is_root_bus(pdev->bus)) pdev = pci_upstream_bridge(pdev); - get_child_devices_rec(pdev, devs, max_devs, &ndev); + get_child_devices_rec(pdev, devs, max_devs, &ndev, pdevs); return ndev; } @@ -776,18 +781,20 @@ static bool virtcca_check_dev_is_assigned_to_nvm(struct tmi_dev_delegate_params * virtcca_get_all_cc_dev_info - Retrieve all devices under the root port * @dev: CC device * @params: Delegate device parameters + * @pdevs: CC pci_dev * * Returns: * %0 if get all devices under the root port successful * %-EINVAL if the total number of devices under the root port exceeds the maximum */ -static int virtcca_get_all_cc_dev_info(struct device *dev, struct tmi_dev_delegate_params *params) +static int virtcca_get_all_cc_dev_info(struct device *dev, + struct tmi_dev_delegate_params *params, struct pci_dev **pdevs) { int ret = 0; uint16_t root_bd = get_root_bd(dev); params->root_bd = root_bd; - params->num_dev = get_sibling_devices(dev, params->devs, MAX_DEV_PER_PORT); + params->num_dev = get_sibling_devices(dev, params->devs, pdevs, MAX_DEV_PER_PORT); if (params->num_dev >= MAX_DEV_PER_PORT) { ret = -EINVAL; pr_err("virtcca_get_all_cc_dev_info nums overflow\n"); @@ -811,6 +818,7 @@ static void virtcca_destroy_devices(struct tmi_dev_delegate_params *params) /** * virtcca_create_cc_dev_ste - Traverse the devices under the root port and set the secure SMMU STE table for them * @smmu: An SMMUv3 instance + * @pdevs: CC pci_dev * @dev: CC device * @params: Delegate device parameters * @s2vmid: SMMU STE s2vmid @@ -819,7 +827,7 @@ static void virtcca_destroy_devices(struct tmi_dev_delegate_params *params) * %0 if set STE success * %-EINVAL set STE config content failed or does not find corresponding master info */ -static int virtcca_create_cc_dev_ste(struct arm_smmu_device *smmu, +static int virtcca_create_cc_dev_ste(struct arm_smmu_device *smmu, struct pci_dev **pdevs, struct device *dev, struct tmi_dev_delegate_params *params, uint16_t *s2vmid) { int ret = 0; @@ -827,7 +835,7 @@ static int virtcca_create_cc_dev_ste(struct arm_smmu_device *smmu, if (!is_cc_root_bd(root_bd)) { /* Get all devices information under the same root port */ - ret = virtcca_get_all_cc_dev_info(dev, params); + ret = virtcca_get_all_cc_dev_info(dev, params, pdevs); if (ret) return ret; @@ -850,6 +858,22 @@ static int virtcca_create_cc_dev_ste(struct arm_smmu_device *smmu, return ret; } +/* Unbind the VF's driver before switching to secure state */ +static int virtcca_unbind_vf_driver(struct pci_dev **pdevs, uint16_t nr) +{ + for (uint16_t i = 0; i < nr; i++) { + if (!pdevs[i]) + return -EINVAL; + /* Only unbind the VF with its own driver */ + if (pdevs[i] == pci_physfn(pdevs[i]) || + !pdevs[i]->driver || + !strcmp(pdevs[i]->driver->name, "vfio-pci")) + continue; + device_release_driver(&pdevs[i]->dev); + } + return 0; +} + /** * virtcca_attach_each_dev_to_cvm - Attach each device under the same group to cvm, * attach device includes setting STE and enabling PCIPC. @@ -873,6 +897,7 @@ static int virtcca_attach_each_dev_to_cvm(struct device *dev, void *domain) struct arm_smmu_domain *smmu_domain = NULL; struct arm_smmu_master *master = NULL; struct tmi_dev_delegate_params *params = NULL; + struct pci_dev *pdevs[MAX_DEV_PER_PORT] = {0}; if (!is_virtcca_cvm_enable()) return 0; @@ -907,7 +932,12 @@ static int virtcca_attach_each_dev_to_cvm(struct device *dev, void *domain) * Obtain device information under the root port * and set security SMMU STE table for it */ - ret = virtcca_create_cc_dev_ste(smmu, dev, params, s2vmid); + ret = virtcca_create_cc_dev_ste(smmu, pdevs, dev, params, s2vmid); + if (ret) + goto out; + + if (!is_cc_root_bd(get_root_bd(dev))) + ret = virtcca_unbind_vf_driver(pdevs, params->num_dev); if (ret) goto out; @@ -1163,4 +1193,4 @@ int virtcca_vdev_create(struct pci_dev *pci_dev) } return ret; } -EXPORT_SYMBOL_GPL(virtcca_vdev_create); \ No newline at end of file +EXPORT_SYMBOL_GPL(virtcca_vdev_create);