论文辅助笔记:TEMPO 之 dataset.py

0 导入库

import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from .utils import StandardScaler, decompose
from .features import time_features

1 Dataset_ETT_hour

1.1 构造函数

class Dataset_ETT_hour(Dataset):
    def __init__(
        self,
        root_path,
        flag="train",
        size=None,
        features="S",
        data_path="ETTh1.csv",
        target="OT",
        scale=True,
        inverse=False,
        timeenc=0,
        freq="h",
        cols=None,
        period=24,
    ):

        if size == None:
            self.seq_len = 24 * 4 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        #输入sequence和输出sequence的长度


        
        assert flag in ["train", "test", "val"]
        type_map = {"train": 0, "val": 1, "test": 2}
        self.set_type = type_map[flag]
        '''
        指定数据集的用途,可以是 "train"、"test" 或 "val",分别对应训练集、测试集和验证集
        '''

        self.features = features
        #指定数据集包含的特征类型,默认为 "S",表示单一特征

        self.target = target
        #指定预测的目标特征


        self.scale = scale
        #一个布尔值,用于确定数据是否需要归一化处理

        self.inverse = inverse
        #一个布尔值,用于决定是否进行逆变换

        self.timeenc = timeenc
        #用于确定是否对时间进行编码【原始模样 or -0.5~0.5区间】

        self.freq = freq
        #定义时间序列的频率,如 "h" 表示小时级别的频率

        self.period = period
        #定义时间序列的周期,默认为 24

        self.root_path = root_path
        self.data_path = data_path

        self.__read_data__()
        #用于读取并初始化数据集

1.2 __read_data__

def __read_data__(self):
        self.scaler = StandardScaler()
        #初始化一个 StandardScaler 对象,用于数据的标准化处理

        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
        #读取数据集文件,将其存储为 DataFrame 对象 df_raw

        border1s = [
            0,
            12 * 30 * 24 - self.seq_len,
            12 * 30 * 24 + 4 * 30 * 24 - self.seq_len,
        ]
        #定义了三个区间的起始位置,分别对应训练集、验证集和测试集
        border2s = [
            12 * 30 * 24,
            12 * 30 * 24 + 4 * 30 * 24,
            12 * 30 * 24 + 8 * 30 * 24,
        ]
        #定义了每个区间的结束位置



        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]
        '''
        通过 self.set_type 确定当前数据集类型

        并从 border1s 和 border2s 中获取对应的起始和结束位置 border1 和 border2
        '''


        if self.features == "M" or self.features == "MS":
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == "S":
            df_data = df_raw[[self.target]]
        '''
        选择特征数据:
            多特征 "M" 或 "MS":选择所有数据列,除去日期列。
            单一特征 "S":只选择目标特征列(由 self.target 指定)。
        '''

        if self.scale:
            train_data = df_data[border1s[0] : border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values
        '''
        如果 self.scale 为 True,则执行数据归一化:
            train_data:选择训练集的数据,用于拟合 self.scaler。
            data:对整个 df_data 进行转换。
        '''

        df_stamp = df_raw[["date"]][border1:border2]
        df_stamp["date"] = pd.to_datetime(df_stamp.date)
        data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq)
        '''
        时间特征处理:
            提取日期列 df_stamp,并将其转换为时间特征:
            pd.to_datetime:将日期转换为 datetime 对象。
            time_features:用于生成时间特征。
        '''

        self.data_x = data[border1:border2]
        if self.inverse:
            self.data_y = df_data.values[border1:border2]
        else:
            self.data_y = data[border1:border2]
        self.data_stamp = data_stamp
        '''
        将转换后的数据和时间特征赋值给 self.data_x、self.data_y 和 self.data_stamp:
            self.data_x 取 data 中的对应区间数据。
            self.data_y 根据 self.inverse 决定是从 data 还是 df_data 中获取。
            self.data_stamp 取生成的时间特征。
        '''

1.3 __getitem__

def __getitem__(self, index):
        s_begin = index
        #设置序列的起始点
        s_end = s_begin + self.seq_len
        #计算序列的结束点
        r_begin = s_end
        #设置预测序列的起始点
        r_end = r_begin + self.pred_len
        #计算预测序列的结束点

        seq_x = self.data_x[s_begin:s_end]
        #从 data_x 中提取序列部分
        seq_y = self.data_y[r_begin:r_end]
        # 从 data_y 中提取预测部分[ground-truth]


        x = torch.tensor(seq_x, dtype=torch.float).transpose(1, 0)  # [1, seq_len]
        y = torch.tensor(seq_y, dtype=torch.float).transpose(1, 0)  # [1, pred_len]

        (trend, seasonal, residual) = decompose(x, period=self.period)
        #对序列 x 进行时间序列分解,返回趋势、季节性和残差三部分
        components = torch.cat((trend, seasonal, residual), dim=0)  # [3, seq_len]
        #将分解后的三部分按 0 维(纵向)拼接,形成一个包含三种特征的张量

        return components, y

1.3__len__

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

1.4  inverse_transform

将数据进行逆转换,还原到原始尺度

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)

2 Dataset_ETT_minute

基本上和hour 的一样,几个地方不一样:

  • __init__
    • data_path="ETTm1.csv",
    • freq="t",
    • period: int = 60,
  • __read_data__
    • border1s = [
                  0,
                  12 * 30 * 24 * 4 - self.seq_len,
                  12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len,
              ]
      border2s = [
                  12 * 30 * 24 * 4,
                  12 * 30 * 24 * 4 + 4 * 30 * 24 * 4,
                  12 * 30 * 24 * 4 + 8 * 30 * 24 * 4,
              ]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/589531.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

加州大学欧文分校英语中级语法专项课程02:Adjectives and Adjective Clauses 学习笔记

Adjectives and Adjective Clauses course certificate 本文是 https://www.coursera.org/learn/adjective-clauses 这门课的学习笔记。 文章目录 Adjectives and Adjective ClausesWeek 01: Adjectives and Adjective PhrasesLearning Objectives Adjectives Introduction Le…

解码Starknet Verifier:深入逆向工程之旅

1. 引言 Sandstorm为: 能提交独立proof给StarkWare的Ethereum Verifier,的首个开源的STARK prover。 开源代码见: https://github.com/andrewmilson/sandstorm(Rust) L2Beat 提供了以太坊上Starknet的合约架构图&…

单链表经典算法

一,移除链表元素 思路一 遍历数组,如果遇到链表中的元素等于val的节点就执行删除操作 typedef struct ListNode ListNode;struct ListNode* removeElements(struct ListNode* head, int val) {if(headNULL){return NULL;} ListNode*pnewhead(ListNode*)m…

14.集合、常见的数据结构

集合 概念 Java中的集合就是一个容器,用来存放Java对象。 集合在存放对象的时候,不同的容器,存放的方法实现是不一样的, Java中将这些不同实现的容器,往上抽取就形成了Java的集合体系。 Java集合中的根接口&#x…

MVC和DDD的贫血和充血模型对比

文章目录 架构区别MVC三层架构DDD四层架构 贫血模型代码示例 充血模型代码示例 架构区别 MVC三层架构 MVC三层架构是软件工程中的一种设计模式,它将软件系统分为 模型(Model)、视图(View)和控制器(Contro…

前端工程化03-贝壳找房项目案例JavaScript常用的js库

4、项目实战(贝壳找房) 这个项目包含,基本的ajax请求调用,内容的渲染,防抖节流的基本使用,ajax请求工具类的封装 4.1、项目的接口文档 下述接口文档: 简述内容baseURL:http://123.207.32.32…

SQL——高级教程【菜鸟教程】

SQL连接 左连接:SQL LEFT JOIN 关键字 左表相当于主表,不管与右表匹不匹配都会显示所有数据 右表就只会显示和左表匹配的内容。 //例显示:左表的name,有表的总数,时间 SELECT Websites.name, access_log.count, acc…

【机器学习-15】决策树(Decision Tree,DT)算法介绍:原理与案例实现

前言 决策树算法是机器学习领域中的一种重要分类方法,它通过树状结构来进行决策分析。决策树凭借其直观易懂、易于解释的特点,在分类问题中得到了广泛的应用。本文将介绍决策树的基本原理,包括熵和信息熵的相关概念,以及几种经典的…

上位机开发PyQt5(二)【单行输入框、多行输入框、按钮的信号和槽】

目录 一、单行输入框QLineEdit QLineEdit的方法: 二、多行输入框QTextEdit QTextEdit的方法 三、按钮QPushButton 四、按钮的信号与槽 信号与槽简介: 信号和槽绑定: 使用PyQt的槽函数 一、单行输入框QLineEdit QLineEdit控件可以输入…

双向链表专题

文章目录 目录1. 双向链表的结构2. 双向链表的实现3. 顺序表和双向链表的优缺点分析 目录 双向链表的结构双向链表的实现顺序表和双向链表的优缺点分析 1. 双向链表的结构 注意: 这⾥的“带头”跟前面我们说的“头节点”是两个概念,带头链表里的头节点…

Redis 实战1

SDS Redis 只会使用 C 字符串作为字面量, 在大多数情况下, Redis 使用 SDS (Simple Dynamic String,简单动态字符串)作为字符串表示。 比起 C 字符串, SDS 具有以下优点: 常数复杂度获取字符串…

JavaEE >> Spring MVC(2)

接上文 本文介绍如何使用 Spring Boot/MVC 项目将程序执行业务逻辑之后的结果返回给用户,以及一些相关内容进行分析解释。 返回静态页面 要返回一个静态页面,首先需要在 resource 中的 static 目录下面创建一个静态页面,下面将创建一个静态…

[嵌入式系统-53]:嵌入式系统集成开发环境大全 ( IAR Embedded Workbench(通用)、MDK(ARM)比较 )

目录 一、嵌入式系统集成开发环境分类 二、由MCU芯片厂家提供的集成开发工具 三、由嵌入式操作提供的集成开发工具 四、由第三方工具厂家提供的集成开发工具 五、开发工具的整合 5.1 Keil MDK for ARM 5.2 IAR Embedded Workbench(通用)、MDK&…

01.本地工作目录、暂存区、本地仓库三者的工作关系

1.持续集成 1.持续集成CI 让产品可以快速迭代,同时还能保持高质量。 简化工作 2.持续交付 交付 3.持续部署 部署 4.持续集成实现的思路 gitjenkins 5.版本控制系统 1.版本控制系统概述2.Git基本概述3.Git基本命令 2.本地工作目录、暂存区、本地仓库三者的工作关系…

抖音评论区精准获客自动化获客释放双手

挺好用的,评论区自动化快速获客,如果手动点引流涨,那就很耗费时间了,不是吗? 网盘自动获取 链接:https://pan.baidu.com/s/1lpzKPim76qettahxvxtjaQ?pwd0b8x 提取码:0b8x

leetcode84柱状图中最大的矩形

题解&#xff1a; - 力扣&#xff08;LeetCode&#xff09; class Solution {public int largestRectangleArea(int[] heights) {Stack<Integer> stack new Stack<>();int maxArea Integer.MIN_VALUE;for(int i 0;i < heights.length;i){int curHeight hei…

YOLOV8添加SKATTENTION

修改ultralytics.nn.modules._init_.py https://zhuanlan.zhihu.com/p/474599120?utm_sourcezhihu&utm id0 https://blog.csdn.net/weixin 42878111/article/details/136060087 https://blog.csdn.net/gg 51511878/aricle/details/138002223 . 最后输出层不一样。

JAVA面试之MQ

如何保证消息的可靠传输&#xff1f;如果消息丢了怎么办 数据的丢失问题&#xff0c;可能出现在生产者、MQ、消费者中。 &#xff08;1&#xff09;生产者发送消息时丢失&#xff1a; ①生产者发送消息时连接MQ失败 ②生产者发送消息到达MQ后未找到Exchange(交换机) ③生产者发…

一对一WebRTC视频通话系列(一)—— 创建页面并显示摄像头画面

本系列博客主要记录WebRtc实现过程中的一些重点&#xff0c;代码全部进行了注释&#xff0c;便于理解WebRTC整体实现。 一、创建html页面 简单添加input、button、video控件的布局。 <html><head><title>WebRTC demo</title></head><h1>…

单片机编程实例400例大全(100-200)

今天继续分享单片机编程实例第100-200例。 今天的实例会比前面100复杂一些&#xff0c;我大概看了下&#xff0c;很多都具备实际产品的参考价值。 今天继续分享单片机编程实例第100-200例。 今天的实例会比前面100复杂一些&#xff0c;我大概看了下&#xff0c;很多都具备实际…
最新文章