@@ -7,7 +7,9 @@ from typing import List, Tuple, Optional, Dict
77import pyudev
88import sys
99import os
10+ import re
1011
12+ context = pyudev .Context ()
1113
1214def set_env_var (key , value ):
1315 print (f"{ key } ={ value } " )
@@ -51,21 +53,19 @@ def get_device_from_devpath(devpath: str) -> pyudev.Device:
5153 Given a DEVPATH (e.g. '/devices/.../tty/ttyUSB0'),
5254 return the corresponding pyudev Device.
5355 """
54- context = pyudev .Context ()
5556 sys_path = os .path .join ("/sys" , devpath .lstrip ("/" ))
5657 return pyudev .Devices .from_sys_path (context , sys_path )
5758
5859
59- def iter_kernel_ancestors (device : pyudev .Device , include_self : bool = False ):
60- """
61- Yield sys_name (kernel name) for the device and its parents.
62-
63- sys_name is what udev matches with KERNELS=="...".
64- """
65- dev = device if include_self else device .parent
66- while dev is not None :
67- yield dev
68- dev = dev .parent
60+ def get_pcieport_parent_sys_name (device : pyudev .Device ) -> Optional [str ]:
61+ try :
62+ pcieport_parent = next (
63+ d for d in device .ancestors if d .driver == "pcieport"
64+ )
65+ return pcieport_parent .sys_name
66+ except StopIteration :
67+ # no pcieport parent found
68+ return None
6969
7070
7171def main ():
@@ -86,22 +86,23 @@ def main():
8686 if not devpath :
8787 set_env_var ("CUSTOM_ALIAS_ERR" , "device path not provided!" )
8888 return
89-
89+
9090 device = get_device_from_devpath (devpath )
91-
92- try :
93- pcieport_parent = next (
94- d for d in iter_kernel_ancestors (device ) if d .driver == "pcieport"
95- )
96- except StopIteration :
97- # no pcieport parent found
98- return
99-
100- if pcieport_parent .sys_name not in alias_parents :
91+ pcieport_parent_sys_name = get_pcieport_parent_sys_name (device )
92+ if pcieport_parent_sys_name is None :
93+ # special case for NVMe:
94+ # try stripping down to nvme subsystem device
95+ nvme_block_name : str = device .sys_name
96+ m = re .match (r"(nvme\d+)n\d+" , nvme_block_name )
97+ if m :
98+ nvme_device = pyudev .Devices .from_name (context , "nvme" , m .group (1 ))
99+ pcieport_parent_sys_name = get_pcieport_parent_sys_name (nvme_device )
100+
101+ if pcieport_parent_sys_name not in alias_parents :
101102 # not mapped
102103 return
103104
104- alias = alias_parents [pcieport_parent . sys_name ]
105+ alias = alias_parents [pcieport_parent_sys_name ]
105106
106107 set_env_var ("ID_VDEV" , alias )
107108 set_env_var ("ID_VDEV_PATH" , os .path .join ("disk/by-vdev" , alias ))
0 commit comments