如何并行化基于 Huggingface 零样本分类(Zero Shot Classification)的分类任务?
How to parallelize classification with Zero Shot Classification by Huggingface?
我有大约 70 个类别(也可以是 20 或 30 个),我希望能够使用 ray 并行处理该过程,但出现错误:
import pandas as pd
import swifter
import json
import ray
from transformers import pipeline
# Zero-shot classification pipeline, created once at module level.
classifier = pipeline("zero-shot-classification")

# Candidate labels handed to the classifier, grouped by food family.
# Spelling fixes relative to the original list: "mashrooms" -> "mushrooms",
# "mangoe" -> "mangoes", "backery" -> "bakery", "souce" -> "sauce".
# The second "corn" (grains group) was dropped so that no candidate label
# appears twice in the list.
labels = ["vegetables", "potato", "bell pepper", "tomato", "onion", "carrot", "broccoli",
          "lettuce", "cucumber", "celery", "corn", "garlic", "mushrooms", "cabbage", "spinach",
          "beans", "cauliflower", "asparagus", "fruits", "bananas", "apples", "strawberries",
          "grapes", "oranges", "lemons", "avocados", "peaches", "blueberries", "pineapple",
          "cherries", "pears", "mangoes", "berries", "red meat", "beef", "pork", "mutton",
          "veal", "lamb", "venison", "goat", "mince", "white meat", "chicken", "turkey",
          "duck", "goose", "pheasant", "rabbit", "Processed meat", "sausages", "bacon",
          "ham", "hot dogs", "frankfurters", "tinned meat", "salami", "pâtés", "beef jerky",
          "chorizo", "pepperoni", "corned beef", "fish", "catfish", "cod", "pangasius", "pollock",
          "tilapia", "tuna", "salmon", "seafood", "shrimp", "squid", "mussels", "scallop",
          "octopus", "grains", "rice", "wheat", "bulgur", "oat", "quinoa", "buckwheat",
          "meals", "salad", "soup", "steak", "pizza", "pie", "burger", "bakery", "bread", "sauce",
          "pasta", "sandwich", "waffles", "barbecue", "roll", "wings", "ribs", "cookies"]

# Initialise Ray (no arguments) before any remote function is invoked.
ray.init()
@ray.remote
def get_meal_category(seq, labels, n=3):
    """Classify one text and return its first ``n`` labels with their scores.

    ``classifier`` is not a parameter of this function: it is resolved from
    the enclosing module scope when the body runs.

    Args:
        seq: Text to classify.
        labels: Candidate labels handed to ``classifier``.
        n: Number of leading label/score pairs to keep.

    Returns:
        A list of ``(seq, label, score)`` tuples, at most ``n`` entries long.
    """
    outcome = classifier(seq, labels)
    # Repeat the input text so every (label, score) pair carries it along.
    repeated = [seq] * n
    return list(zip(repeated, outcome["labels"][:n], outcome["scores"][:n]))


# Submit one remote task per index 0..9 of the "title" column, then block
# until every result is available.
res_list = ray.get(
    [
        get_meal_category.remote(merged_df["title"][row], labels)
        for row in range(10)
    ]
)
其中 merged_df 是一个很大的数据框(DataFrame),其 title 列中包含膳食名称,例如:
['Cappuccino',
'Stove Top Stuffing Mix For Turkey (Kraft)',
'Stove Top Stuffing Mix For Turkey (Kraft)',
'Roasted Dark Turkey Meat',
'Roasted Dark Turkey Meat',
'Roasted Dark Turkey Meat',
'Cappuccino',
'Low Fat 2% Small Curd Cottage Cheese (Daisy)',
'Rice Cereal (Gerber)',
'Oranges']
请指教如何避免ray的错误和并行化分类。
错误:
2021-02-17 16:54:51,689 WARNING worker.py:1107 -- Warning: The remote function __main__.get_meal_category has size 1630925709 when pickled. It will be stored in Redis, which could cause memory issues. This may mean that its definition uses a large array or other object.
---------------------------------------------------------------------------
ConnectionResetError Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self, command, check_health)
705 for item in command:
--> 706 sendall(self._sock, item)
707 except socket.timeout:
~/.local/lib/python3.8/site-packages/redis/_compat.py in sendall(sock, *args, **kwargs)
8 def sendall(sock, *args, **kwargs):
----> 9 return sock.sendall(*args, **kwargs)
10
ConnectionResetError: [Errno 104] Connection reset by peer
During handling of the above exception, another exception occurred:
ConnectionError Traceback (most recent call last)
<ipython-input-9-1a5345832fba> in <module>
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])
<ipython-input-9-1a5345832fba> in <listcomp>(.0)
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])
~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote_proxy(*args, **kwargs)
99 @wraps(function)
100 def _remote_proxy(*args, **kwargs):
--> 101 return self._remote(args=args, kwargs=kwargs)
102
103 self.remote = _remote_proxy
~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote(self, args, kwargs, num_returns, num_cpus, num_gpus, memory, object_store_memory, accelerator_type, resources, max_retries, placement_group, placement_group_bundle_index, placement_group_capture_child_tasks, override_environment_variables, name)
205
206 self._last_export_session_and_job = worker.current_session_and_job
--> 207 worker.function_actor_manager.export(self)
208
209 kwargs = {} if kwargs is None else kwargs
~/.local/lib/python3.8/site-packages/ray/function_manager.py in export(self, remote_function)
142 key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
143 + remote_function._function_descriptor.function_id.binary())
--> 144 self._worker.redis_client.hset(
145 key,
146 mapping={
~/.local/lib/python3.8/site-packages/redis/client.py in hset(self, name, key, value, mapping)
3048 items.extend(pair)
3049
-> 3050 return self.execute_command('HSET', name, *items)
3051
3052 def hsetnx(self, name, key, value):
~/.local/lib/python3.8/site-packages/redis/client.py in execute_command(self, *args, **options)
898 conn = self.connection or pool.get_connection(command_name, **options)
899 try:
--> 900 conn.send_command(*args)
901 return self.parse_response(conn, command_name, **options)
902 except (ConnectionError, TimeoutError) as e:
~/.local/lib/python3.8/site-packages/redis/connection.py in send_command(self, *args, **kwargs)
723 def send_command(self, *args, **kwargs):
724 "Pack and send a command to the Redis server"
--> 725 self.send_packed_command(self.pack_command(*args),
726 check_health=kwargs.get('check_health', True))
727
~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self, command, check_health)
715 errno = e.args[0]
716 errmsg = e.args[1]
--> 717 raise ConnectionError("Error %s while writing to socket. %s." %
718 (errno, errmsg))
719 except BaseException:
ConnectionError: Error 104 while writing to socket. Connection reset by peer.
发生此错误是因为向 Redis 发送了过大的对象。merged_df 是一个很大的数据框,由于您调用了 10 次 get_meal_category,Ray 会尝试将 merged_df 序列化 10 次。相反,如果您先把 merged_df 放入 Ray 对象存储一次,然后传递对该对象的引用,就应该可以正常工作。
编辑:由于分类器也很大,因此也做类似的事情。
你能试试这样吗:
# Initialise Ray, then store the dataframe and the classifier with ray.put
# once and keep the returned references to hand to the tasks below.
ray.init()
df_ref = ray.put(merged_df)
model_ref = ray.put(classifier)


@ray.remote
def get_meal_category(classifier, df, i, labels, n=3):
    """Classify the title at index ``i`` and return its first ``n`` labels.

    Args:
        classifier: Callable invoked as ``classifier(text, labels)``; its
            result is indexed with ``"labels"`` and ``"scores"``.
        df: Object indexed as ``df["title"][i]`` to obtain the text.
        i: Index of the row whose title is classified.
        labels: Candidate labels handed to ``classifier``.
        n: Number of leading label/score pairs to keep.

    Returns:
        A list of ``(title, label, score)`` tuples, at most ``n`` entries long.
    """
    seq = df["title"][i]
    outcome = classifier(seq, labels)
    # Repeat the title so every (label, score) pair carries it along.
    repeated = [seq] * n
    return list(zip(repeated, outcome["labels"][:n], outcome["scores"][:n]))


# Submit one remote task per index 0..9, passing the stored references, then
# block until every result is available.
res_list = ray.get(
    [
        get_meal_category.remote(model_ref, df_ref, row, labels)
        for row in range(10)
    ]
)
我有大约 70 个类别(也可以是 20 或 30 个),我希望能够使用 ray 并行处理该过程,但出现错误:
import pandas as pd
import swifter
import json
import ray
from transformers import pipeline
# Zero-shot classification pipeline, created once at module level.
classifier = pipeline("zero-shot-classification")

# Candidate labels handed to the classifier, grouped by food family.
# Spelling fixes relative to the original list: "mashrooms" -> "mushrooms",
# "mangoe" -> "mangoes", "backery" -> "bakery", "souce" -> "sauce".
# The second "corn" (grains group) was dropped so that no candidate label
# appears twice in the list.
labels = ["vegetables", "potato", "bell pepper", "tomato", "onion", "carrot", "broccoli",
          "lettuce", "cucumber", "celery", "corn", "garlic", "mushrooms", "cabbage", "spinach",
          "beans", "cauliflower", "asparagus", "fruits", "bananas", "apples", "strawberries",
          "grapes", "oranges", "lemons", "avocados", "peaches", "blueberries", "pineapple",
          "cherries", "pears", "mangoes", "berries", "red meat", "beef", "pork", "mutton",
          "veal", "lamb", "venison", "goat", "mince", "white meat", "chicken", "turkey",
          "duck", "goose", "pheasant", "rabbit", "Processed meat", "sausages", "bacon",
          "ham", "hot dogs", "frankfurters", "tinned meat", "salami", "pâtés", "beef jerky",
          "chorizo", "pepperoni", "corned beef", "fish", "catfish", "cod", "pangasius", "pollock",
          "tilapia", "tuna", "salmon", "seafood", "shrimp", "squid", "mussels", "scallop",
          "octopus", "grains", "rice", "wheat", "bulgur", "oat", "quinoa", "buckwheat",
          "meals", "salad", "soup", "steak", "pizza", "pie", "burger", "bakery", "bread", "sauce",
          "pasta", "sandwich", "waffles", "barbecue", "roll", "wings", "ribs", "cookies"]

# Initialise Ray (no arguments) before any remote function is invoked.
ray.init()
@ray.remote
def get_meal_category(seq, labels, n=3):
    """Classify one text and return its first ``n`` labels with their scores.

    ``classifier`` is not a parameter of this function: it is resolved from
    the enclosing module scope when the body runs.

    Args:
        seq: Text to classify.
        labels: Candidate labels handed to ``classifier``.
        n: Number of leading label/score pairs to keep.

    Returns:
        A list of ``(seq, label, score)`` tuples, at most ``n`` entries long.
    """
    outcome = classifier(seq, labels)
    # Repeat the input text so every (label, score) pair carries it along.
    repeated = [seq] * n
    return list(zip(repeated, outcome["labels"][:n], outcome["scores"][:n]))


# Submit one remote task per index 0..9 of the "title" column, then block
# until every result is available.
res_list = ray.get(
    [
        get_meal_category.remote(merged_df["title"][row], labels)
        for row in range(10)
    ]
)
其中 merged_df 是一个很大的数据框(DataFrame),其 title 列中包含膳食名称,例如:
['Cappuccino',
'Stove Top Stuffing Mix For Turkey (Kraft)',
'Stove Top Stuffing Mix For Turkey (Kraft)',
'Roasted Dark Turkey Meat',
'Roasted Dark Turkey Meat',
'Roasted Dark Turkey Meat',
'Cappuccino',
'Low Fat 2% Small Curd Cottage Cheese (Daisy)',
'Rice Cereal (Gerber)',
'Oranges']
请指教如何避免ray的错误和并行化分类。
错误:
2021-02-17 16:54:51,689 WARNING worker.py:1107 -- Warning: The remote function __main__.get_meal_category has size 1630925709 when pickled. It will be stored in Redis, which could cause memory issues. This may mean that its definition uses a large array or other object.
---------------------------------------------------------------------------
ConnectionResetError Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self, command, check_health)
705 for item in command:
--> 706 sendall(self._sock, item)
707 except socket.timeout:
~/.local/lib/python3.8/site-packages/redis/_compat.py in sendall(sock, *args, **kwargs)
8 def sendall(sock, *args, **kwargs):
----> 9 return sock.sendall(*args, **kwargs)
10
ConnectionResetError: [Errno 104] Connection reset by peer
During handling of the above exception, another exception occurred:
ConnectionError Traceback (most recent call last)
<ipython-input-9-1a5345832fba> in <module>
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])
<ipython-input-9-1a5345832fba> in <listcomp>(.0)
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])
~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote_proxy(*args, **kwargs)
99 @wraps(function)
100 def _remote_proxy(*args, **kwargs):
--> 101 return self._remote(args=args, kwargs=kwargs)
102
103 self.remote = _remote_proxy
~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote(self, args, kwargs, num_returns, num_cpus, num_gpus, memory, object_store_memory, accelerator_type, resources, max_retries, placement_group, placement_group_bundle_index, placement_group_capture_child_tasks, override_environment_variables, name)
205
206 self._last_export_session_and_job = worker.current_session_and_job
--> 207 worker.function_actor_manager.export(self)
208
209 kwargs = {} if kwargs is None else kwargs
~/.local/lib/python3.8/site-packages/ray/function_manager.py in export(self, remote_function)
142 key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
143 + remote_function._function_descriptor.function_id.binary())
--> 144 self._worker.redis_client.hset(
145 key,
146 mapping={
~/.local/lib/python3.8/site-packages/redis/client.py in hset(self, name, key, value, mapping)
3048 items.extend(pair)
3049
-> 3050 return self.execute_command('HSET', name, *items)
3051
3052 def hsetnx(self, name, key, value):
~/.local/lib/python3.8/site-packages/redis/client.py in execute_command(self, *args, **options)
898 conn = self.connection or pool.get_connection(command_name, **options)
899 try:
--> 900 conn.send_command(*args)
901 return self.parse_response(conn, command_name, **options)
902 except (ConnectionError, TimeoutError) as e:
~/.local/lib/python3.8/site-packages/redis/connection.py in send_command(self, *args, **kwargs)
723 def send_command(self, *args, **kwargs):
724 "Pack and send a command to the Redis server"
--> 725 self.send_packed_command(self.pack_command(*args),
726 check_health=kwargs.get('check_health', True))
727
~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self, command, check_health)
715 errno = e.args[0]
716 errmsg = e.args[1]
--> 717 raise ConnectionError("Error %s while writing to socket. %s." %
718 (errno, errmsg))
719 except BaseException:
ConnectionError: Error 104 while writing to socket. Connection reset by peer.
发生此错误是因为向 Redis 发送了过大的对象。merged_df 是一个很大的数据框,由于您调用了 10 次 get_meal_category,Ray 会尝试将 merged_df 序列化 10 次。相反,如果您先把 merged_df 放入 Ray 对象存储一次,然后传递对该对象的引用,就应该可以正常工作。
编辑:由于分类器也很大,因此也做类似的事情。
你能试试这样吗:
# Initialise Ray, then store the dataframe and the classifier with ray.put
# once and keep the returned references to hand to the tasks below.
ray.init()
df_ref = ray.put(merged_df)
model_ref = ray.put(classifier)


@ray.remote
def get_meal_category(classifier, df, i, labels, n=3):
    """Classify the title at index ``i`` and return its first ``n`` labels.

    Args:
        classifier: Callable invoked as ``classifier(text, labels)``; its
            result is indexed with ``"labels"`` and ``"scores"``.
        df: Object indexed as ``df["title"][i]`` to obtain the text.
        i: Index of the row whose title is classified.
        labels: Candidate labels handed to ``classifier``.
        n: Number of leading label/score pairs to keep.

    Returns:
        A list of ``(title, label, score)`` tuples, at most ``n`` entries long.
    """
    seq = df["title"][i]
    outcome = classifier(seq, labels)
    # Repeat the title so every (label, score) pair carries it along.
    repeated = [seq] * n
    return list(zip(repeated, outcome["labels"][:n], outcome["scores"][:n]))


# Submit one remote task per index 0..9, passing the stored references, then
# block until every result is available.
res_list = ray.get(
    [
        get_meal_category.remote(model_ref, df_ref, row, labels)
        for row in range(10)
    ]
)