Feature order of downloaded data is scrambled #181

Open
beng003 opened this issue Dec 4, 2024 · 2 comments
Comments

beng003 commented Dec 4, 2024

Issue Type

Api Usage

Have you searched for existing issues?

Yes

Link to Relevant Documentation

No response

Question Details

The output of my custom component has its features in a scrambled order. How can I specify the order of the features the component outputs?
beng003 (Author) commented Dec 4, 2024

My custom component is shown below, but the feature order of the output data does not match the feature order of the input data. How can I specify the feature order of the output data?

[Screenshot: QQ20241204-193003]

Attachments: max_min_table.csv, input_table.csv

import os
import numpy as np
import pandas as pd
import jax.numpy as jnp

from secretflow.component.component import (
    Component,
    IoType,
    CompEvalError,
)

from secretflow.component.data_utils import DistDataType

from secretflow.spec.v1.data_pb2 import (
    DistData,
    IndividualTable,
)

from secretflow.device.device.spu import SPU
from secretflow.device.device.pyu import PYU
from secretflow.device.driver import wait

# Component definition
normalization_comp = Component(
    "normalize_data",
    domain="user",
    version="0.1.0",
    desc="Normalize data using Min-Max Scaling.",
)

normalization_comp.io(
    io_type=IoType.INPUT,
    name="input_table_max_min_path",
    desc="Path to the table containing Min-Max values for each feature (CSV format).",
    types=[DistDataType.INDIVIDUAL_TABLE],
)
normalization_comp.io(
    io_type=IoType.INPUT,
    name="input_table_path",
    desc="Path to the input table (CSV format).",
    types=[DistDataType.INDIVIDUAL_TABLE],
)
normalization_comp.io(
    io_type=IoType.OUTPUT,
    name="output_normalized_table",
    desc="Path to the normalized output table.",
    types=[DistDataType.INDIVIDUAL_TABLE],
)


def df_to_jax_matrix(input_table_data, feature_list=None, output_path=None):
    # Extract the specified columns
    if feature_list is not None:
        extracted_data = input_table_data[feature_list]
    else:
        extracted_data = input_table_data

    # Coerce values to numeric (non-numeric values become NaN)
    numeric_data = extracted_data.apply(pd.to_numeric, errors='coerce')
    # Convert to a JAX matrix
    input_table_jax_matrix = jnp.array(numeric_data.to_numpy())

    if output_path is not None:
        # Save the processed data to the output file
        numeric_data.to_csv(output_path, index=False)

    return input_table_jax_matrix


def normalize_matrix(matrix1, matrix2):
    """
    Normalize each column of the first matrix with min-max scaling.

    Args:
        matrix1 (jnp.ndarray): JAX matrix of shape (m, n).
        matrix2 (jnp.ndarray): JAX matrix of shape (2, n) or (n, 2).

    Returns:
        jnp.ndarray: Normalized matrix of shape (m, n).
    """
    if matrix2.shape[0] != 2 or matrix1.shape[1] != matrix2.shape[1]:
        matrix2 = jnp.transpose(matrix2)

    # Ensure matrix2 has shape (2, n)
    assert matrix2.shape[0] == 2, "matrix2 should have two rows (2, n)"
    assert matrix1.shape[1] == matrix2.shape[1], "both matrices should have the same number of columns (n)"

    # Split the two rows of matrix2
    min_vals = matrix2[0, :]  # first row: minimum values
    max_vals = matrix2[1, :]  # second row: maximum values

    # Compute the per-column range
    range_vals = max_vals - min_vals

    # Guard against division by zero
    range_vals = jnp.where(range_vals == 0, 1e-6, range_vals)

    # Normalize each column
    normalized_matrix = (matrix1 - min_vals) / range_vals

    return normalized_matrix


def assign_columns_with_jax_matrix(df, column_list, jax_matrix):
    """
    Assign the given columns from the JAX matrix while keeping the DataFrame's
    original column order.

    Args:
        df (pd.DataFrame): Input Pandas DataFrame.
        column_list (list): Column names to assign, in order.
        jax_matrix (jnp.ndarray): JAX matrix whose shape must match the column
            list length and the DataFrame's row count.

    Returns:
        pd.DataFrame: Updated DataFrame with its column order unchanged.
    """
    # Check that the column list and matrix dimensions agree
    assert len(column_list) == jax_matrix.shape[1], "column list length must match the number of matrix columns"
    assert len(df) == jax_matrix.shape[0], "DataFrame row count must match the number of matrix rows"

    # Convert the JAX matrix to a NumPy array
    numpy_matrix = np.asarray(jax_matrix)

    # Back up the original column order
    original_columns = df.columns.tolist()

    # Update the specified columns in order
    for i, col in enumerate(column_list):
        if col in df.columns:
            df[col] = numpy_matrix[:, i]
        else:
            raise ValueError(f"Column '{col}' does not exist in the DataFrame")

    # Keep the original column order
    return df[original_columns]


def feature_select(max_min_table):
    return max_min_table["Feature"].tolist()


def save_table_to_csv(table, path):
    table.to_csv(path, index=False)


@normalization_comp.eval_fn
def normalization_eval_fn(
    *,
    ctx,
    input_table_max_min_path,
    input_table_path,
    output_normalized_table,
):
    # Input and output paths
    input_table_path_alice = os.path.join(ctx.data_dir, input_table_path.data_refs[0].uri)
    max_min_table_path_bob = os.path.join(ctx.data_dir, input_table_max_min_path.data_refs[0].uri)
    output_path_alice = os.path.join(ctx.data_dir, output_normalized_table)
    
    alice_str = input_table_path.data_refs[0].party
    bob_str = input_table_max_min_path.data_refs[0].party

    # Initialize devices
    # get spu config from ctx
    if ctx.spu_configs is None or len(ctx.spu_configs) == 0:
        raise CompEvalError("spu config is not found.")
    if len(ctx.spu_configs) > 1:
        raise CompEvalError("only support one spu")
    spu_config = next(iter(ctx.spu_configs.values()))
    
    spu_device = SPU(spu_config["cluster_def"], spu_config["link_desc"])
    alice = PYU(alice_str)
    bob = PYU(bob_str)

    input_table_alice = alice(pd.read_csv)(input_table_path_alice)
    max_min_table_bob = bob(pd.read_csv)(max_min_table_path_bob)
    feature_alice = bob(feature_select)(max_min_table_bob).to(alice)

    input_table_matrix_alice = alice(df_to_jax_matrix)(input_table_alice, feature_alice)
    max_min_table_matrix_bob = bob(df_to_jax_matrix)(
        max_min_table_bob, ["MinValues", "MaxValues"]
    )

    input_table_matrix_spu = input_table_matrix_alice.to(spu_device)
    max_min_table_matrix_spu = max_min_table_matrix_bob.to(spu_device)

    out_table_matrix_spu = spu_device(normalize_matrix)(
        input_table_matrix_spu, max_min_table_matrix_spu
    )
    out_table_matrix_alice = out_table_matrix_spu.to(alice)
    out_table_df_alice = alice(assign_columns_with_jax_matrix)(
        input_table_alice, feature_alice, out_table_matrix_alice
    )
    
    wait(alice(save_table_to_csv)(out_table_df_alice, output_path_alice))

    # Build the output DistData
    alice_data = DistData(
        name="output_normalized_table",
        type=str(DistDataType.INDIVIDUAL_TABLE),
        data_refs=[
            DistData.DataRef(
                uri=output_normalized_table, party=alice.party, format="csv"
            )
        ],
    )

    alice_meta = IndividualTable()
    input_table_path.meta.Unpack(alice_meta)
    alice_data.meta.Pack(alice_meta)

    return {"output_normalized_table": alice_data}

@lanyy9527

Try adjusting the column order of out_table_df_alice:

out_table_df_alice = alice(lambda df, cols: df[cols])(
    out_table_df_alice, input_table_alice.columns.tolist()
)
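
Note that input_table_alice is a PYUObject returned by alice(pd.read_csv)(...), so reading .columns on it directly on the driver may not work. A minimal sketch (based on the code above, not verified) that fetches the column order inside the PYU call before saving:

# Sketch: reorder the output DataFrame to match the input table's column order
# entirely inside alice's PYU, then save as before.
out_table_df_alice = alice(lambda out_df, in_df: out_df[in_df.columns.tolist()])(
    out_table_df_alice, input_table_alice
)
wait(alice(save_table_to_csv)(out_table_df_alice, output_path_alice))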
