Pyarrow basic auth: How to prevent `Stream is closed`?

I'm new to Arrow Flight and pyarrow (v=6.0.1). I'm trying to implement basic authentication, but I always get this error:

OSError: Stream is closed

I created a minimal reproducible example. Run the two parts in order (first the server, then the client):

# Server (run this first)
from typing import Dict, Union
from pyarrow.lib import tobytes
from pyarrow.flight import BasicAuth, FlightUnauthenticatedError, ServerAuthHandler, FlightServerBase
from pyarrow._flight import ServerAuthSender, ServerAuthReader


class ServerBasicAuthHandler(ServerAuthHandler):
    def __init__(self, creds: Dict[str, str]):
        self.creds = {user.encode(): pw.encode() for user, pw in creds.items()}

    def authenticate(self, outgoing: ServerAuthSender, incoming: ServerAuthReader):
        buf = incoming.read()  # this line raises "OSError: Stream is closed"
        auth = BasicAuth.deserialize(buf)
        if auth.username not in self.creds:
            raise FlightUnauthenticatedError("unknown user")
        if self.creds[auth.username] != auth.password:
            raise FlightUnauthenticatedError("wrong password")
        outgoing.write(tobytes(auth.username))

    def is_valid(self, token: bytes) -> Union[bytes, str]:
        if not token:
            raise FlightUnauthenticatedError("no basic auth provided")
        if token not in self.creds:
            raise FlightUnauthenticatedError("unknown user")
        return token

service = FlightServerBase(
    location="grpc://[::]:50051",
    auth_handler=ServerBasicAuthHandler({"user": "pw"}),
)

service.serve()

# Client (run in a second process while the server is up)
from pyarrow.flight import FlightClient

client = FlightClient(location="grpc://localhost:50051")
client.authenticate_basic_token("user", "pw")

I basically copied the ServerAuthHandler implementation from their tests, so it is proven to work. However, I can't get it working.

The error message `Stream is closed` is hard to debug. I can't trace where it comes from anywhere in the pyarrow implementation, neither on the Python side nor on the C++ side.

Any help or hints on how to prevent this error would be greatly appreciated.

My guess is that this is simply not supported on Windows.

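On closer inspection, the test that "proves it works" is skipped on Windows. The comment there refers to this issue. That issue has (ostensibly) been resolved, so I don't know why it still fails with `Stream is closed`.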

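The example in the OP mixes up two different authentication implementations (which is admittedly confusing). The `BasicAuth` object is not the actual HTTP basic authentication that the `authenticate_basic_token` method implements; over the years, contributors have added several authentication mechanisms. `BasicAuth` belongs to the older handshake API, where the client calls `client.authenticate()` with a `ClientAuthHandler` that writes the serialized `BasicAuth` into the handshake stream. `authenticate_basic_token` instead sends the credentials as an `authorization: Basic ...` header and, as far as I can tell from the C++ client, closes its side of the handshake stream without writing anything, so the server handler's `incoming.read()` finds the stream already closed. The actual test for `authenticate_basic_token` looks like this: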

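# Excerpt from pyarrow's test suite (test_flight.py); the helper classes are defined in that module.
header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()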
no_op_auth_handler = NoopAuthHandler()


def test_authenticate_basic_token():
    """Test authenticate_basic_token with bearer token and auth headers."""
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
        "auth": HeaderAuthServerMiddlewareFactory()
    }) as server:
        client = FlightClient(('localhost', server.port))
        token_pair = client.authenticate_basic_token(b'test', b'password')
        assert token_pair[0] == b'authorization'
        assert token_pair[1] == b'Bearer token1234'
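The header that `authenticate_basic_token` sends, and that the test's server middleware checks, is ordinary HTTP Basic authentication (RFC 7617): the word `Basic`, a space, then base64 of `username:password`. By contrast, `flight.BasicAuth(...).serialize()` produces a Flight-specific protobuf payload meant to be written into the handshake stream. The stdlib-only sketch below just illustrates the header format; `make_basic_header` and `parse_basic_header` are hypothetical helpers, not pyarrow APIs, and this is not how pyarrow builds the header internally.

```python
import base64
from typing import Tuple


def make_basic_header(username: str, password: str) -> str:
    """Build an HTTP Basic 'authorization' header value (RFC 7617 style)."""
    raw = f"{username}:{password}".encode("utf-8")
    return "Basic " + base64.b64encode(raw).decode("ascii")


def parse_basic_header(value: str) -> Tuple[str, str]:
    """Split 'Basic <base64>' back into (username, password).

    Only the first ':' separates the fields, so passwords may contain ':'.
    """
    scheme, _, encoded = value.partition(" ")
    if scheme != "Basic":
        raise ValueError(f"unexpected auth scheme: {scheme!r}")
    decoded = base64.b64decode(encoded).decode("utf-8")
    username, sep, password = decoded.partition(":")
    if not sep:
        raise ValueError("malformed credentials: missing ':'")
    return username, password


header = make_basic_header("test", "password")
print(header)                      # Basic dGVzdDpwYXNzd29yZA==
print(parse_basic_header(header))  # ('test', 'password')
```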

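In other words, the test does not go through `authenticate` at all; the credentials are checked by server "middleware" instead. A complete, self-contained example looks like this: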

import base64
import pyarrow.flight as flight

class BasicAuthServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    def __init__(self, creds):
        self.creds = creds

    def start_call(self, info, headers):
        # Runs for every incoming RPC, including the Handshake issued by authenticate_basic_token
        token = None
        for header in headers:
            if header.lower() == "authorization":
                token = headers[header]
                break

        if not token:
            raise flight.FlightUnauthenticatedError("No credentials supplied")

        values = token[0].split(' ', 1)
        if len(values) == 2 and values[0] == 'Basic':
            decoded = base64.b64decode(values[1])
            # Split on the first ':' only, so passwords may contain ':'
            pair = decoded.decode("utf-8").split(':', 1)
            if len(pair) != 2 or pair[0] not in self.creds:
                raise flight.FlightUnauthenticatedError("Invalid credentials")
            if pair[1] != self.creds[pair[0]]:
                raise flight.FlightUnauthenticatedError("Invalid credentials")
            return BasicAuthServerMiddleware("BearerTokenValue")

        raise flight.FlightUnauthenticatedError("No credentials supplied")


class BasicAuthServerMiddleware(flight.ServerMiddleware):
    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        # Sent back to the client; authenticate_basic_token returns this header as its result
        return {'authorization': f'Bearer {self.token}'}


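# No-op handshake handler: authenticate_basic_token still issues a Handshake RPC,
# which (as far as I can tell) is rejected as unimplemented when the server has no
# auth_handler. The real credential check happens in the middleware above.
class NoOpAuthHandler(flight.ServerAuthHandler):
    def authenticate(self, outgoing, incoming):
        pass

    def is_valid(self, token):
        # Accept every call at this layer; "" is the (empty) peer identity
        return ""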


with flight.FlightServerBase(auth_handler=NoOpAuthHandler(), middleware={
    "basic": BasicAuthServerMiddlewareFactory({"test": "password"})
}) as server:
    client = flight.connect(('localhost', server.port))
    token_pair = client.authenticate_basic_token(b'test', b'password')
    print(token_pair)
    assert token_pair[0] == b'authorization'
    assert token_pair[1] == b'Bearer BearerTokenValue'
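One gap in the example above: `start_call` runs for every RPC, but it only accepts `Basic` credentials. After `authenticate_basic_token`, a client normally sends the returned pair back on later calls (for example via `flight.FlightCallOptions(headers=[token_pair])`), i.e. as `authorization: Bearer ...`, and the middleware above would reject those calls. The sketch below pulls the check into a plain function that accepts both schemes. `check_authorization` and `Unauthenticated` are hypothetical names; `Unauthenticated` is a stand-in for `flight.FlightUnauthenticatedError` so the logic can be tested without starting a server, and the single fixed token mirrors `"BearerTokenValue"` above rather than a real session-token scheme.

```python
import base64
import binascii
import hmac
from typing import Dict, List, Mapping


class Unauthenticated(Exception):
    """Hypothetical stand-in for flight.FlightUnauthenticatedError."""


def check_authorization(headers: Mapping[str, List[str]],
                        creds: Dict[str, str],
                        issued_token: str) -> str:
    """Validate an 'authorization' header; return the bearer token to send back."""
    # Headers arrive as {name: [values, ...]}; look the name up case-insensitively
    values = next((v for k, v in headers.items() if k.lower() == "authorization"), None)
    if not values:
        raise Unauthenticated("No credentials supplied")

    scheme, _, payload = values[0].partition(" ")
    if scheme == "Basic":
        # First call (authenticate_basic_token): base64("user:password")
        try:
            decoded = base64.b64decode(payload, validate=True).decode("utf-8")
        except (binascii.Error, UnicodeDecodeError):
            raise Unauthenticated("Malformed Basic credentials") from None
        user, _, password = decoded.partition(":")
        expected = creds.get(user)
        # Constant-time comparison of the password bytes
        if expected is None or not hmac.compare_digest(expected.encode(), password.encode()):
            raise Unauthenticated("Invalid credentials")
        return issued_token
    if scheme == "Bearer":
        # Later calls: the client echoes back the token it was given
        if not hmac.compare_digest(payload.encode(), issued_token.encode()):
            raise Unauthenticated("Invalid token")
        return payload
    raise Unauthenticated(f"Unsupported auth scheme: {scheme!r}")


creds = {"test": "password"}
basic = {"authorization": ["Basic " + base64.b64encode(b"test:password").decode()]}
bearer = {"authorization": ["Bearer BearerTokenValue"]}
print(check_authorization(basic, creds, "BearerTokenValue"))   # BearerTokenValue
print(check_authorization(bearer, creds, "BearerTokenValue"))  # BearerTokenValue
```

Inside `BasicAuthServerMiddlewareFactory.start_call` you would call this, re-raise `Unauthenticated` as `flight.FlightUnauthenticatedError`, and return `BasicAuthServerMiddleware(token)`. A real server would issue per-client tokens (e.g. from `secrets.token_urlsafe()`) and track them, rather than use a constant.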