that's too much!
This commit is contained in:
207
.venv/lib/python3.12/site-packages/pyogrio/tests/test_arrow.py
Normal file
207
.venv/lib/python3.12/site-packages/pyogrio/tests/test_arrow.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import contextlib
|
||||
import math
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from pyogrio import __gdal_version__, read_dataframe
|
||||
from pyogrio.raw import open_arrow, read_arrow
|
||||
from pyogrio.tests.conftest import requires_arrow_api
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
from pandas.testing import assert_frame_equal, assert_index_equal
|
||||
from geopandas.testing import assert_geodataframe_equal
|
||||
|
||||
import pyarrow
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# skip all tests in this file if Arrow API or GeoPandas are unavailable
|
||||
pytestmark = requires_arrow_api
|
||||
pytest.importorskip("geopandas")
|
||||
|
||||
|
||||
def test_read_arrow(naturalearth_lowres_all_ext):
|
||||
result = read_dataframe(naturalearth_lowres_all_ext, use_arrow=True)
|
||||
expected = read_dataframe(naturalearth_lowres_all_ext, use_arrow=False)
|
||||
|
||||
if naturalearth_lowres_all_ext.suffix.startswith(".geojson"):
|
||||
check_less_precise = True
|
||||
else:
|
||||
check_less_precise = False
|
||||
assert_geodataframe_equal(result, expected, check_less_precise=check_less_precise)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("skip_features, expected", [(10, 167), (200, 0)])
|
||||
def test_read_arrow_skip_features(naturalearth_lowres, skip_features, expected):
|
||||
table = read_arrow(naturalearth_lowres, skip_features=skip_features)[1]
|
||||
assert len(table) == expected
|
||||
|
||||
|
||||
def test_read_arrow_negative_skip_features(naturalearth_lowres):
|
||||
with pytest.raises(ValueError, match="'skip_features' must be >= 0"):
|
||||
read_arrow(naturalearth_lowres, skip_features=-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_features, expected", [(0, 0), (10, 10), (200, 177), (100000, 177)]
|
||||
)
|
||||
def test_read_arrow_max_features(naturalearth_lowres, max_features, expected):
|
||||
table = read_arrow(naturalearth_lowres, max_features=max_features)[1]
|
||||
assert len(table) == expected
|
||||
|
||||
|
||||
def test_read_arrow_negative_max_features(naturalearth_lowres):
|
||||
with pytest.raises(ValueError, match="'max_features' must be >= 0"):
|
||||
read_arrow(naturalearth_lowres, max_features=-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"skip_features, max_features, expected",
|
||||
[
|
||||
(0, 0, 0),
|
||||
(10, 0, 0),
|
||||
(200, 0, 0),
|
||||
(1, 200, 176),
|
||||
(176, 10, 1),
|
||||
(100, 100, 77),
|
||||
(100, 100000, 77),
|
||||
],
|
||||
)
|
||||
def test_read_arrow_skip_features_max_features(
|
||||
naturalearth_lowres, skip_features, max_features, expected
|
||||
):
|
||||
table = read_arrow(
|
||||
naturalearth_lowres, skip_features=skip_features, max_features=max_features
|
||||
)[1]
|
||||
assert len(table) == expected
|
||||
|
||||
|
||||
def test_read_arrow_fid(naturalearth_lowres_all_ext):
|
||||
kwargs = {"use_arrow": True, "where": "fid >= 2 AND fid <= 3"}
|
||||
|
||||
df = read_dataframe(naturalearth_lowres_all_ext, fid_as_index=False, **kwargs)
|
||||
assert_index_equal(df.index, pd.RangeIndex(0, 2))
|
||||
|
||||
df = read_dataframe(naturalearth_lowres_all_ext, fid_as_index=True, **kwargs)
|
||||
assert_index_equal(df.index, pd.Index([2, 3], name="fid"))
|
||||
|
||||
|
||||
def test_read_arrow_columns(naturalearth_lowres):
|
||||
result = read_dataframe(naturalearth_lowres, use_arrow=True, columns=["continent"])
|
||||
assert result.columns.tolist() == ["continent", "geometry"]
|
||||
|
||||
|
||||
def test_read_arrow_ignore_geometry(naturalearth_lowres):
|
||||
result = read_dataframe(naturalearth_lowres, use_arrow=True, read_geometry=False)
|
||||
assert type(result) is pd.DataFrame
|
||||
|
||||
expected = read_dataframe(naturalearth_lowres, use_arrow=True).drop(
|
||||
columns=["geometry"]
|
||||
)
|
||||
assert_frame_equal(result, expected)
|
||||
|
||||
|
||||
def test_read_arrow_nested_types(test_ogr_types_list):
|
||||
# with arrow, list types are supported
|
||||
result = read_dataframe(test_ogr_types_list, use_arrow=True)
|
||||
assert "list_int64" in result.columns
|
||||
assert result["list_int64"][0].tolist() == [0, 1]
|
||||
|
||||
|
||||
def test_read_arrow_to_pandas_kwargs(test_fgdb_vsi):
|
||||
# with arrow, list types are supported
|
||||
arrow_to_pandas_kwargs = {"strings_to_categorical": True}
|
||||
result = read_dataframe(
|
||||
test_fgdb_vsi,
|
||||
use_arrow=True,
|
||||
arrow_to_pandas_kwargs=arrow_to_pandas_kwargs,
|
||||
)
|
||||
assert "SEGMENT_NAME" in result.columns
|
||||
assert result["SEGMENT_NAME"].dtype.name == "category"
|
||||
|
||||
|
||||
def test_read_arrow_raw(naturalearth_lowres):
|
||||
meta, table = read_arrow(naturalearth_lowres)
|
||||
assert isinstance(meta, dict)
|
||||
assert isinstance(table, pyarrow.Table)
|
||||
|
||||
|
||||
def test_open_arrow(naturalearth_lowres):
|
||||
with open_arrow(naturalearth_lowres) as (meta, reader):
|
||||
assert isinstance(meta, dict)
|
||||
assert isinstance(reader, pyarrow.RecordBatchReader)
|
||||
assert isinstance(reader.read_all(), pyarrow.Table)
|
||||
|
||||
|
||||
def test_open_arrow_batch_size(naturalearth_lowres):
|
||||
meta, table = read_arrow(naturalearth_lowres)
|
||||
batch_size = math.ceil(len(table) / 2)
|
||||
|
||||
with open_arrow(naturalearth_lowres, batch_size=batch_size) as (meta, reader):
|
||||
assert isinstance(meta, dict)
|
||||
assert isinstance(reader, pyarrow.RecordBatchReader)
|
||||
count = 0
|
||||
tables = []
|
||||
for table in reader:
|
||||
tables.append(table)
|
||||
count += 1
|
||||
|
||||
assert count == 2, "Should be two batches given the batch_size parameter"
|
||||
assert len(tables[0]) == batch_size, "First table should match the batch size"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
__gdal_version__ >= (3, 8, 0),
|
||||
reason="skip_features supported by Arrow stream API for GDAL>=3.8.0",
|
||||
)
|
||||
@pytest.mark.parametrize("skip_features", [10, 200])
|
||||
def test_open_arrow_skip_features_unsupported(naturalearth_lowres, skip_features):
|
||||
"""skip_features are not supported for the Arrow stream interface for
|
||||
GDAL < 3.8.0"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="specifying 'skip_features' is not supported for Arrow for GDAL<3.8.0",
|
||||
):
|
||||
with open_arrow(naturalearth_lowres, skip_features=skip_features) as (
|
||||
meta,
|
||||
reader,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_features", [10, 200])
|
||||
def test_open_arrow_max_features_unsupported(naturalearth_lowres, max_features):
|
||||
"""max_features are not supported for the Arrow stream interface"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="specifying 'max_features' is not supported for Arrow",
|
||||
):
|
||||
with open_arrow(naturalearth_lowres, max_features=max_features) as (
|
||||
meta,
|
||||
reader,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_arrow_context():
|
||||
original = os.environ.get("PYOGRIO_USE_ARROW", None)
|
||||
os.environ["PYOGRIO_USE_ARROW"] = "1"
|
||||
yield
|
||||
if original:
|
||||
os.environ["PYOGRIO_USE_ARROW"] = original
|
||||
else:
|
||||
del os.environ["PYOGRIO_USE_ARROW"]
|
||||
|
||||
|
||||
def test_enable_with_environment_variable(test_ogr_types_list):
|
||||
# list types are only supported with arrow, so don't work by default and work
|
||||
# when arrow is enabled through env variable
|
||||
result = read_dataframe(test_ogr_types_list)
|
||||
assert "list_int64" not in result.columns
|
||||
|
||||
with use_arrow_context():
|
||||
result = read_dataframe(test_ogr_types_list)
|
||||
assert "list_int64" in result.columns
|
||||
Reference in New Issue
Block a user