forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmulti_output_model.py
More file actions
executable file
·44 lines (36 loc) · 1.41 KB
/
multi_output_model.py
File metadata and controls
executable file
·44 lines (36 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import json
import argparse
import torch
class MultiOutputModel(torch.nn.Module):
    """Toy model that scores several (input, target) pairs with one shared layer.

    Every input goes through the same bias-free square ``linear`` layer, and the
    result is scored against its target with ``torch.nn.CrossEntropyLoss``,
    producing one loss per pair.
    """

    def __init__(self, hidden_dim, weight_value):
        """Create the shared layer with all weights set to ``weight_value``.

        Args:
            hidden_dim: in- and out-feature size of the square linear layer.
            weight_value: constant written into every entry of the weight matrix.
        """
        super().__init__()
        # Registration order (linear first, then the loss) matches the original,
        # so named_children()/state_dict() ordering is unchanged.
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear.weight.data.fill_(weight_value)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Return a tuple holding one cross-entropy loss per (input, target) pair.

        Pairs are formed with ``zip``, so surplus items in the longer of
        ``inputs`` / ``targets`` are ignored.
        """
        return tuple(
            self.cross_entropy_loss(self.linear(sample), label)
            for sample, label in zip(inputs, targets)
        )
def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, targets, *, dtype=torch.half):
    """Build a DataLoader of constant-valued samples for a multi-output model.

    For each value in ``inputs`` a ``(total_samples, hidden_dim)`` tensor filled
    with that value is created (with ``requires_grad=True``). For each value in
    ``targets`` a ``(total_samples,)`` ``torch.long`` tensor filled with that
    value is created. All of them go into one ``TensorDataset``, so every batch
    is a list laid out as
    ``[input_0, ..., input_{k-1}, target_0, ..., target_{k-1}]``.

    Args:
        model: object exposing ``train_micro_batch_size_per_gpu()``; its return
            value is used as the DataLoader batch size (presumably a DeepSpeed
            engine -- confirm against callers).
        total_samples: number of rows in every generated tensor.
        hidden_dim: feature width of each input tensor.
        device: device on which every tensor is allocated.
        inputs: fill values, one per input stream.
        targets: fill values for the label tensors, one per target stream;
            must have the same length as ``inputs``.
        dtype: dtype of the input tensors. Defaults to ``torch.half`` (the
            previously hard-coded value); labels are always ``torch.long``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset described above.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and a mismatch would misalign the inputs/targets layout.
    if len(inputs) != len(targets):
        raise ValueError(
            f"inputs and targets must have the same length, got {len(inputs)} and {len(targets)}")
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = [
        torch.full(size=(total_samples, hidden_dim),
                   fill_value=x,
                   device=device,
                   dtype=dtype,
                   requires_grad=True) for x in inputs
    ]
    train_label = [
        torch.empty(total_samples, device=device, dtype=torch.long).fill_(y) for y in targets
    ]
    train_dataset = torch.utils.data.TensorDataset(*train_data, *train_label)
    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)