
Commit f15c599

Infer types from lance blobs (#7966)
1 parent 5b122f7 commit f15c599

File tree

4 files changed: +64 -6 lines changed


src/datasets/features/audio.py

Lines changed: 6 additions & 0 deletions
@@ -252,6 +252,12 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str
         if pa.types.is_string(storage.type):
             bytes_array = pa.array([None] * len(storage), type=pa.binary())
             storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
+        elif pa.types.is_large_binary(storage.type):
+            storage = array_cast(
+                storage, pa.binary()
+            )  # this can fail in case of big audios, paths should be used instead
+            path_array = pa.array([None] * len(storage), type=pa.string())
+            storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
         elif pa.types.is_binary(storage.type):
             path_array = pa.array([None] * len(storage), type=pa.string())
             storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
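
The new elif branch (added identically to the Image and Video features below) handles columns whose Arrow storage is large_binary, as Lance blob columns arrive here: the payload is downcast to binary and paired with an all-null "path" column so it fits the struct<bytes, path> storage the feature expects. A minimal standalone sketch of that conversion in plain PyArrow; the sample value and variable names are illustrative, not from the commit:

import pyarrow as pa

# Illustrative stand-in for the added branch: a large_binary column is downcast
# to binary and wrapped into the struct<bytes, path> storage.
large_bin = pa.array([b"ID3\x04\x00tiny-mp3-payload", None], type=pa.large_binary())
bytes_array = large_bin.cast(pa.binary())  # may fail for very large payloads, hence the comment in the diff
path_array = pa.array([None] * len(bytes_array), type=pa.string())
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=large_bin.is_null())
print(storage.type)  # struct<bytes: binary, path: string>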

src/datasets/features/image.py

Lines changed: 6 additions & 0 deletions
@@ -242,6 +242,12 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
         if pa.types.is_string(storage.type):
             bytes_array = pa.array([None] * len(storage), type=pa.binary())
             storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
+        elif pa.types.is_large_binary(storage.type):
+            storage = array_cast(
+                storage, pa.binary()
+            )  # this can fail in case of big images, paths should be used instead
+            path_array = pa.array([None] * len(storage), type=pa.string())
+            storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
         elif pa.types.is_binary(storage.type):
             path_array = pa.array([None] * len(storage), type=pa.string())
             storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())

src/datasets/features/video.py

Lines changed: 6 additions & 0 deletions
@@ -258,6 +258,12 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
         if pa.types.is_string(storage.type):
             bytes_array = pa.array([None] * len(storage), type=pa.binary())
             storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
+        elif pa.types.is_large_binary(storage.type):
+            storage = array_cast(
+                storage, pa.binary()
+            )  # this can fail in case of big videos, paths should be used instead
+            path_array = pa.array([None] * len(storage), type=pa.string())
+            storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
         elif pa.types.is_binary(storage.type):
             path_array = pa.array([None] * len(storage), type=pa.string())
             storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())

src/datasets/packaged_modules/lance/lance.py

Lines changed: 46 additions & 6 deletions
@@ -7,6 +7,7 @@
 from huggingface_hub import HfApi
 
 import datasets
+from datasets import Audio, Image, Video
 from datasets.builder import Key
 from datasets.table import table_cast
 from datasets.utils.file_utils import is_local_path
@@ -18,6 +19,23 @@
 
 logger = datasets.utils.logging.get_logger(__name__)
 
+MAGIC_BYTES_EXTENSION_AND_FEATURE_TYPES = [
+    ("1A 45 DF A3", ".mkv", Video()),
+    ("66 74 79 70 69 73 6F 6D", ".mp4", Video()),
+    ("66 74 79 70 4D 53 4E 56", ".mp4", Video()),
+    ("52 49 46 46", ".avi", Video()),
+    ("00 00 01 BA", ".mpeg", Video()),
+    ("00 00 01 BA", ".mpeg", Video()),
+    ("00 00 01 B3", ".mov", Video()),
+    ("89 50 4E 47", ".png", Image()),
+    ("FF D8", ".jpg", Image()),
+    ("49 49", ".tif", Image()),
+    ("47 49 46 38", ".gif", Image()),
+    ("52 49 46 46", ".wav", Audio()),
+    ("49 44 33", ".mp3", Audio()),
+    ("66 4C 61 43", ".flac", Audio()),
+]
+
 
 @dataclass
 class LanceConfig(datasets.BuilderConfig):
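
MAGIC_BYTES_EXTENSION_AND_FEATURE_TYPES maps well-known file signatures to an extension and a feature type. A dependency-free sketch of how such a table can be matched against the first bytes of a blob; the helper name and the trimmed-down table with string labels are illustrative only, and the actual lookup happens in _split_generators (last hunk below):

# Illustrative helper, not part of the commit.
SIGNATURES = [
    ("89 50 4E 47", ".png", "Image"),
    ("FF D8", ".jpg", "Image"),
    ("49 44 33", ".mp3", "Audio"),
    ("66 74 79 70 69 73 6F 6D", ".mp4", "Video"),
]

def sniff_feature_type(first_bytes: bytes):
    for magic_bytes_hex, _, feature_type in SIGNATURES:
        magic_bytes = bytes.fromhex(magic_bytes_hex)
        # checking containment in a window of twice the signature length tolerates
        # a short prefix, e.g. the 4-byte box size that precedes "ftyp" in MP4 files
        if magic_bytes in first_bytes[: len(magic_bytes) * 2]:
            return feature_type
    return None

print(sniff_feature_type(b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"))  # Image
print(sniff_feature_type(b"\x00\x00\x00\x18ftypisomiso2avc1"))     # Video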
@@ -104,13 +122,23 @@ def _split_generators(self, dl_manager):
 
             lance_dataset_uris = resolve_dataset_uris(files)
             if lance_dataset_uris:
-                fragments = [
-                    frag
-                    for uri in lance_dataset_uris
-                    for frag in lance.dataset(uri, storage_options=storage_options).get_fragments()
-                ]
+                lance_datasets = [lance.dataset(uri, storage_options=storage_options) for uri in lance_dataset_uris]
+                fragments = [frag for lance_dataset in lance_datasets for frag in lance_dataset.get_fragments()]
                 if self.info.features is None:
                     pa_schema = fragments[0]._ds.schema
+                    first_row_first_bytes = {}
+                    for field in pa_schema:
+                        if self.config.columns is not None and field.name not in self.config.columns:
+                            continue
+                        if pa.types.is_binary(field.type) or pa.types.is_large_binary(field.type):
+                            try:
+                                first_row_first_bytes[field.name] = (
+                                    lance_datasets[0].take_blobs(field.name, [0])[0].read(16)
+                                )
+                            except ValueError:
+                                first_row_first_bytes[field.name] = (
+                                    lance_datasets[0].take([0], [field.name]).to_pylist()[0][field.name][:16]
+                                )
                 splits.append(
                     datasets.SplitGenerator(
                         name=split_name,
@@ -124,6 +152,11 @@ def _split_generators(self, dl_manager):
                 ]
                 if self.info.features is None:
                     pa_schema = lance_files[0].metadata().schema
+                    first_row_first_bytes = {
+                        field_name: value[:16]
+                        for field_name, value in lance_files[0].take_rows([0]).to_table().to_pylist()[0].items()
+                        if isinstance(value, bytes)
+                    }
                 splits.append(
                     datasets.SplitGenerator(
                         name=split_name,
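
Both branches probe only the first row and keep at most 16 bytes per binary or large_binary column, enough to cover the longest signature in the table (8 bytes) plus the padding window. The commit reads that row through Lance (take_blobs on blob-encoded columns, falling back to take / take_rows); the sketch below reproduces the same probe on a plain Arrow table, with illustrative column names and values:

import pyarrow as pa

# Illustrative stand-in for the probe: take the first row and keep the first
# 16 bytes of every bytes-valued column.
table = pa.table(
    {
        "image": pa.array([b"\x89PNG\r\n\x1a\n" + b"\x00" * 100], type=pa.large_binary()),
        "caption": pa.array(["a tiny png"]),
    }
)
first_row = table.slice(0, 1).to_pylist()[0]
first_row_first_bytes = {
    field_name: value[:16] for field_name, value in first_row.items() if isinstance(value, bytes)
}
print(first_row_first_bytes)  # {'image': b'\x89PNG\r\n\x1a\n\x00\x00\x00\x00\x00\x00\x00\x00'}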
@@ -136,7 +169,14 @@ def _split_generators(self, dl_manager):
                         pa_schema.field(name) for name in self.config.columns if pa_schema.get_field_index(name) != -1
                     ]
                     pa_schema = pa.schema(fields)
-                self.info.features = datasets.Features.from_arrow_schema(pa_schema)
+                features = datasets.Features.from_arrow_schema(pa_schema)
+                for field_name, first_bytes in first_row_first_bytes.items():
+                    for magic_bytes_hex, _, feature_type in MAGIC_BYTES_EXTENSION_AND_FEATURE_TYPES:
+                        magic_bytes = bytes.fromhex(magic_bytes_hex)
+                        if magic_bytes in first_bytes[: len(magic_bytes) * 2]:  # allow some padding
+                            features[field_name] = feature_type
+                            break
+                self.info.features = features
 
         return splits
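
Net effect: features are still inferred from the Arrow schema, but any binary column whose first bytes match a known signature is upgraded to the corresponding Audio, Image, or Video feature before being stored on self.info.features. A small sketch of that last step with an assumed schema (column names are illustrative):

import pyarrow as pa
import datasets
from datasets import Image

# Illustrative: schema-only inference sees the blob column as plain binary data;
# the magic-byte match then swaps in the richer feature type for that column.
pa_schema = pa.schema({"image": pa.large_binary(), "caption": pa.string()})
features = datasets.Features.from_arrow_schema(pa_schema)
print(features["image"])  # a plain binary-typed Value before the override
features["image"] = Image()
print(features["image"])  # Image(...) afterwards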
