Official description of the dataset:
class ChickenpoxDatasetLoader(object):
    """A dataset of county level chicken pox cases in Hungary between 2004
    and 2014. We made it public during the development of PyTorch Geometric
    Temporal. The underlying graph is static - vertices are counties and
    edges are neighbourhoods. Vertex features are lagged weekly counts of the
    chickenpox cases (we included 4 lags). The target is the weekly number of
    cases for the upcoming week (signed integers). Our dataset consist of more
    than 500 snapshots (weeks).
    """
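Key points:
A dataset of county-level chickenpox cases in Hungary from 2004 to 2014.
The graph is static: vertices are counties, and edges connect neighbouring counties.
Vertex features are lagged weekly chickenpox case counts (4 lags are included).
Target: the number of chickenpox cases in the following week.
The dataset contains more than 500 temporal snapshots (weeks).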
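To make "lagged features" concrete, here is a minimal sketch for a single county. The helper name `make_lagged_pairs` and the weekly counts are made up for illustration; they are not part of the library or the real data. Each feature window holds `lags` consecutive weekly counts, and the target is the count of the week that follows.

```python
def make_lagged_pairs(series, lags=4):
    """Pair each window of `lags` consecutive values with the value that follows it."""
    pairs = []
    for i in range(len(series) - lags):
        window = series[i:i + lags]   # features: counts for weeks i .. i+lags-1
        target = series[i + lags]     # target: count for the following week
        pairs.append((window, target))
    return pairs


# Made-up weekly case counts for one hypothetical county.
weekly_counts = [12, 7, 9, 15, 20, 18, 11]
pairs = make_lagged_pairs(weekly_counts, lags=4)

for window, target in pairs:
    print(window, "->", target)
```

A series of 7 weeks with 4 lags yields 7 - 4 = 3 pairs, which is why the number of snapshots is always the number of weeks minus the number of lags.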
How the official loader reads the dataset:
import json
import urllib.request

import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

    # Methods of ChickenpoxDatasetLoader (excerpt)

    def _read_web_data(self):
        # Download the raw JSON file and parse it into a dict.
        url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/chickenpox.json"
        self._dataset = json.loads(urllib.request.urlopen(url).read())

    def _get_edges(self):
        # Load the edges; after the transpose the shape is (2, num_edges).
        self._edges = np.array(self._dataset["edges"]).T

    def _get_edge_weights(self):
        # One unit weight per edge (the array length equals the number of edges).
        self._edge_weights = np.ones(self._edges.shape[1])

    def _get_targets_and_features(self):
        # Node attributes: build the time series.
        # stacked_target has shape (m, n): m time steps, n nodes.
        stacked_target = np.array(self._dataset["FX"])
        # Time slicing: each feature block covers `lags` consecutive weeks,
        # and the target is the week right after that window.
        self.features = [stacked_target[i:i+self.lags,:].T for i in range(stacked_target.shape[0]-self.lags)]
        self.targets = [stacked_target[i+self.lags,:].T for i in range(stacked_target.shape[0]-self.lags)]

    def get_dataset(self, lags: int=4) -> StaticGraphTemporalSignal:
        """Returning the Chickenpox Hungary data iterator.

        Args types:
            * **lags** *(int)* - The number of time lags.
        Return types:
            * **dataset** *(StaticGraphTemporalSignal)* - The Chickenpox Hungary dataset.
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = StaticGraphTemporalSignal(self._edges, self._edge_weights, self.features, self.targets)
        return dataset
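The shapes produced by these steps are easier to follow on a tiny example. The sketch below mirrors the same slicing in plain Python rather than NumPy, so it is an illustration of the idea and not the library's implementation. Only the key names "edges" and "FX" come from the loader code; the payload values, the `transpose` helper, and `lags = 2` are made up for illustration.

```python
import json

# Made-up miniature payload: 3 counties, 4 weeks, 4 directed edges.
raw = json.dumps({
    "edges": [[0, 1], [1, 0], [1, 2], [2, 1]],
    "FX": [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
})
data = json.loads(raw)


def transpose(rows):
    """Swap rows and columns of a list of lists (the role played by .T above)."""
    return [list(col) for col in zip(*rows)]


lags = 2  # smaller than the default of 4 so the toy data still yields snapshots

edge_index = transpose(data["edges"])        # 2 x num_edges: row 0 sources, row 1 destinations
edge_weights = [1.0] * len(data["edges"])    # every edge gets weight 1

fx = data["FX"]                              # num_weeks x num_nodes
num_snapshots = len(fx) - lags
features = [transpose(fx[i:i + lags]) for i in range(num_snapshots)]  # each: num_nodes x lags
targets = [fx[i + lags] for i in range(num_snapshots)]                # each: num_nodes values

print(edge_index)
print(features[0], targets[0])
```

Each snapshot therefore pairs a nodes-by-lags feature matrix with one target value per node, while the edge list and edge weights are shared across all snapshots because the graph is static.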
Dataset analysis:
dataset:
Original post: https://www.cnblogs.com/xmd-home/p/14814675.html