208 lines
6.8 KiB
Python
208 lines
6.8 KiB
Python
import contextlib
|
|
import math
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from pyogrio import __gdal_version__, read_dataframe
|
|
from pyogrio.raw import open_arrow, read_arrow
|
|
from pyogrio.tests.conftest import requires_arrow_api
|
|
|
|
try:
|
|
import pandas as pd
|
|
from pandas.testing import assert_frame_equal, assert_index_equal
|
|
from geopandas.testing import assert_geodataframe_equal
|
|
|
|
import pyarrow
|
|
except ImportError:
|
|
pass
|
|
|
|
# skip all tests in this file if Arrow API or GeoPandas are unavailable
|
|
pytestmark = requires_arrow_api
|
|
pytest.importorskip("geopandas")
|
|
|
|
|
|
def test_read_arrow(naturalearth_lowres_all_ext):
    """Reading with ``use_arrow=True`` must match reading with ``use_arrow=False``."""
    via_arrow = read_dataframe(naturalearth_lowres_all_ext, use_arrow=True)
    via_default = read_dataframe(naturalearth_lowres_all_ext, use_arrow=False)

    # Files whose suffix starts with ".geojson" are compared less precisely.
    relaxed = naturalearth_lowres_all_ext.suffix.startswith(".geojson")
    assert_geodataframe_equal(via_arrow, via_default, check_less_precise=relaxed)
|
|
|
|
|
|
@pytest.mark.parametrize("skip_features, expected", [(10, 167), (200, 0)])
def test_read_arrow_skip_features(naturalearth_lowres, skip_features, expected):
    """Check the length of the table returned for each ``skip_features`` value."""
    _meta, arrow_table = read_arrow(naturalearth_lowres, skip_features=skip_features)
    assert len(arrow_table) == expected
|
|
|
|
|
|
def test_read_arrow_negative_skip_features(naturalearth_lowres):
    """``read_arrow`` must raise ``ValueError`` for a negative ``skip_features``."""
    expected_message = "'skip_features' must be >= 0"
    with pytest.raises(ValueError, match=expected_message):
        read_arrow(naturalearth_lowres, skip_features=-1)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "max_features, expected",
    [
        (0, 0),
        (10, 10),
        (200, 177),
        (100000, 177),
    ],
)
def test_read_arrow_max_features(naturalearth_lowres, max_features, expected):
    """Check the length of the table returned for each ``max_features`` value."""
    _meta, arrow_table = read_arrow(naturalearth_lowres, max_features=max_features)
    assert len(arrow_table) == expected
|
|
|
|
|
|
def test_read_arrow_negative_max_features(naturalearth_lowres):
    """``read_arrow`` must raise ``ValueError`` for a negative ``max_features``."""
    expected_message = "'max_features' must be >= 0"
    with pytest.raises(ValueError, match=expected_message):
        read_arrow(naturalearth_lowres, max_features=-1)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "skip_features, max_features, expected",
    [
        (0, 0, 0),
        (10, 0, 0),
        (200, 0, 0),
        (1, 200, 176),
        (176, 10, 1),
        (100, 100, 77),
        (100, 100000, 77),
    ],
)
def test_read_arrow_skip_features_max_features(
    naturalearth_lowres, skip_features, max_features, expected
):
    """Check the table length when ``skip_features`` and ``max_features`` combine."""
    read_options = {"skip_features": skip_features, "max_features": max_features}
    _meta, arrow_table = read_arrow(naturalearth_lowres, **read_options)
    assert len(arrow_table) == expected
|
|
|
|
|
|
def test_read_arrow_fid(naturalearth_lowres_all_ext):
    """Check the resulting index with ``fid_as_index`` off and on (Arrow path)."""
    where_clause = "fid >= 2 AND fid <= 3"

    # Without fid_as_index the two selected rows get a default RangeIndex.
    without_fid = read_dataframe(
        naturalearth_lowres_all_ext,
        fid_as_index=False,
        use_arrow=True,
        where=where_clause,
    )
    assert_index_equal(without_fid.index, pd.RangeIndex(0, 2))

    # With fid_as_index the index holds the selected FIDs and is named "fid".
    with_fid = read_dataframe(
        naturalearth_lowres_all_ext,
        fid_as_index=True,
        use_arrow=True,
        where=where_clause,
    )
    assert_index_equal(with_fid.index, pd.Index([2, 3], name="fid"))
|
|
|
|
|
|
def test_read_arrow_columns(naturalearth_lowres):
    """Requesting one column via Arrow returns that column plus ``geometry``."""
    frame = read_dataframe(naturalearth_lowres, use_arrow=True, columns=["continent"])
    assert frame.columns.tolist() == ["continent", "geometry"]
|
|
|
|
|
|
def test_read_arrow_ignore_geometry(naturalearth_lowres):
    """``read_geometry=False`` must give a plain DataFrame without geometry."""
    attributes_only = read_dataframe(
        naturalearth_lowres, use_arrow=True, read_geometry=False
    )
    # Exact type comparison: unlike isinstance(), this rejects DataFrame subclasses.
    assert type(attributes_only) is pd.DataFrame

    full_frame = read_dataframe(naturalearth_lowres, use_arrow=True)
    assert_frame_equal(attributes_only, full_frame.drop(columns=["geometry"]))
|
|
|
|
|
|
def test_read_arrow_nested_types(test_ogr_types_list):
    """List-typed columns must be present when reading with ``use_arrow=True``."""
    frame = read_dataframe(test_ogr_types_list, use_arrow=True)
    assert "list_int64" in frame.columns

    first_value = frame["list_int64"][0]
    assert first_value.tolist() == [0, 1]
|
|
|
|
|
|
def test_read_arrow_to_pandas_kwargs(test_fgdb_vsi):
    """``strings_to_categorical`` in ``arrow_to_pandas_kwargs`` gives a category column."""
    frame = read_dataframe(
        test_fgdb_vsi,
        use_arrow=True,
        arrow_to_pandas_kwargs={"strings_to_categorical": True},
    )
    assert "SEGMENT_NAME" in frame.columns
    assert frame["SEGMENT_NAME"].dtype.name == "category"
|
|
|
|
|
|
def test_read_arrow_raw(naturalearth_lowres):
    """``read_arrow`` must return a ``dict`` and a ``pyarrow.Table``."""
    metadata, arrow_table = read_arrow(naturalearth_lowres)
    assert isinstance(metadata, dict)
    assert isinstance(arrow_table, pyarrow.Table)
|
|
|
|
|
|
def test_open_arrow(naturalearth_lowres):
    """``open_arrow`` must yield a ``dict`` and a reader that reads into a Table."""
    with open_arrow(naturalearth_lowres) as (metadata, batch_reader):
        assert isinstance(metadata, dict)
        assert isinstance(batch_reader, pyarrow.RecordBatchReader)

        materialized = batch_reader.read_all()
        assert isinstance(materialized, pyarrow.Table)
|
|
|
|
|
|
def test_open_arrow_batch_size(naturalearth_lowres):
    """A batch size of half the row count (rounded up) must give two batches."""
    _meta, full_table = read_arrow(naturalearth_lowres)
    batch_size = math.ceil(len(full_table) / 2)

    with open_arrow(naturalearth_lowres, batch_size=batch_size) as (meta, reader):
        assert isinstance(meta, dict)
        assert isinstance(reader, pyarrow.RecordBatchReader)

        # Consume every batch inside the `with` block.
        batches = list(reader)
        n_batches = len(batches)

        assert n_batches == 2, "Should be two batches given the batch_size parameter"
        assert len(batches[0]) == batch_size, "First table should match the batch size"
|
|
|
|
|
|
@pytest.mark.skipif(
    __gdal_version__ >= (3, 8, 0),
    reason="skip_features supported by Arrow stream API for GDAL>=3.8.0",
)
@pytest.mark.parametrize("skip_features", [10, 200])
def test_open_arrow_skip_features_unsupported(naturalearth_lowres, skip_features):
    """``open_arrow`` must reject ``skip_features`` with ``ValueError``.

    The test is skipped when ``__gdal_version__`` is 3.8.0 or newer.
    """
    expected_message = (
        "specifying 'skip_features' is not supported for Arrow for GDAL<3.8.0"
    )
    with pytest.raises(ValueError, match=expected_message):
        with open_arrow(naturalearth_lowres, skip_features=skip_features) as (
            _meta,
            _reader,
        ):
            pass
|
|
|
|
|
|
@pytest.mark.parametrize("max_features", [10, 200])
def test_open_arrow_max_features_unsupported(naturalearth_lowres, max_features):
    """``open_arrow`` must reject ``max_features`` with ``ValueError``."""
    expected_message = "specifying 'max_features' is not supported for Arrow"
    with pytest.raises(ValueError, match=expected_message):
        with open_arrow(naturalearth_lowres, max_features=max_features) as (
            _meta,
            _reader,
        ):
            pass
|
|
|
|
|
|
@contextlib.contextmanager
def use_arrow_context():
    """Temporarily set the ``PYOGRIO_USE_ARROW`` environment variable to ``"1"``.

    On exit the variable is put back exactly as it was found: restored to its
    previous value (an empty string included) or removed if it was not set.
    The restore also runs when the managed block raises, so a failing test
    cannot leak the setting into tests that run afterwards.
    """
    name = "PYOGRIO_USE_ARROW"
    original = os.environ.get(name)
    os.environ[name] = "1"
    try:
        yield
    finally:
        # Compare against None, not truthiness: "" is a legitimate prior value.
        if original is not None:
            os.environ[name] = original
        else:
            # pop() rather than del: tolerate the block having removed it already.
            os.environ.pop(name, None)
|
|
|
|
|
|
def test_enable_with_environment_variable(test_ogr_types_list):
    """The list column is read only while ``use_arrow_context`` is active."""
    # list types are only supported with arrow, so the column is absent by default
    default_frame = read_dataframe(test_ogr_types_list)
    assert "list_int64" not in default_frame.columns

    # ... and present once arrow is enabled through the environment variable
    with use_arrow_context():
        arrow_frame = read_dataframe(test_ogr_types_list)
    assert "list_int64" in arrow_frame.columns
|