コンテンツにスキップ

Python/PyTorch

出典: フリー教科書『ウィキブックス(Wikibooks)』
Wikipedia
Wikipedia
ウィキペディアPyTorchの記事があります。

Python/PyTorchの概要

[編集]

PyTorch(パイトーチ)は、Python向けのオープンソースの機械学習フレームワークです。PyTorchは、テンソル計算をベースにした柔軟で高度な数値計算を提供し、ニューラルネットワークをはじめとする機械学習モデルの構築、訓練、デプロイメントを容易に行うことができます。PyTorchは、深層学習研究者やデータサイエンティストに人気があり、アカデミックな研究から実用的なアプリケーションまで幅広い分野で活用されています。

主な機能

[編集]
  1. テンソル計算: PyTorchは高速なテンソル計算をサポートしており、GPUを利用して高速な演算を実現します。テンソルは多次元の配列を表現し、機械学習モデルのデータやパラメータを効率的に扱うことができます。
  2. ニューラルネットワーク: PyTorchはニューラルネットワークの構築と訓練を簡単に行うことができます。モジュール化された設計や自動微分などの機能が備わっており、複雑なネットワークの実装が容易になります。
  3. データローダー: PyTorchはデータローダーを提供しており、大規模なデータセットを効率的に読み込み、バッチ処理を行うことができます。これにより、データの前処理やミニバッチ学習を簡単に実現できます。
  4. モデルの保存と読み込み: PyTorchはモデルのパラメータや構造を保存し、後で読み込むことができる機能を提供します。学習済みモデルの再利用や転移学習などに役立ちます。
  5. インテグレーション: PyTorchは他のPythonライブラリとのシームレスなインテグレーションをサポートしています。NumPySciPyなどのライブラリとの連携が容易に行えます。

Python/PyTorchのコード例

[編集]

コードは未検証です。

テンソルの操作

[編集]
import torch

# データセットの作成
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])

# テンソルの演算
z = x + y

# テンソルの表示
print(z)

ニューラルネットワークの構築

[編集]
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# データセットの作成
inputs = torch.randn(100, 10)
targets = torch.randint(0, 2, (100,))

# データローダーの作成
dataset = TensorDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# ネットワークの定義
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ネットワークのインスタンス化
net = SimpleNet()

# 損失関数と最適化関数の定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# ネットワークの訓練
for epoch in range(10):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

データセットの作成とテンソルの演算

[編集]

上記のコード例では、2つのテンソルを作成して足し算を行い、結果を表示しています。

ニューラルネットワークの構築と訓練

[編集]

上記のコード例では、簡単なニューラルネットワークを定義し、ランダムなデータセットを用いてネットワークの訓練を行っています。データセットは100個の入力データと対応するラベル(0または1)から構成されており、データローダーを使用してバッチごとにデータを取得しています。

これらのコード例を実行することで、PyTorchのテンソル操作やニューラルネットワークの構築と訓練の基本的な方法を学ぶことができます。

実践?

[編集]
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns
import matplotlib.pyplot as plt

# データセットの作成
inputs = torch.randn(100, 10)
targets = torch.randint(0, 2, (100,))

# データローダーの作成
dataset = TensorDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# テストデータの作成
test_inputs = torch.randn(20, 10)
test_targets = torch.randint(0, 2, (20,))
test_dataset = TensorDataset(test_inputs, test_targets)
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# ネットワークの定義
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ネットワークのインスタンス化
net = SimpleNet()

# 損失関数と最適化関数の定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# リストにエポックごとの損失を記録
losses = []

# ネットワークの訓練
for epoch in range(300):
    running_loss = 0.0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    # エポックごとの損失を記録
    epoch_loss = running_loss / len(dataloader)
    losses.append(epoch_loss)

    # エポックごとの損失を表示
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

# テストデータを用いてネットワークの精度を確認
correct = 0
total = 0
with torch.no_grad():
    for test_inputs, test_targets in test_dataloader:
        test_outputs = net(test_inputs)
        _, predicted = torch.max(test_outputs.data, 1)
        total += test_targets.size(0)
        correct += (predicted == test_targets).sum().item()

print(f"Accuracy on test data: {100 * correct / total}%")

# 損失のグラフを作成
sns.set()
plt.plot(losses, label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

結果

[編集]
PyTorchTraingLossGraph.png
Epoch 1, Loss: 0.7172316431999206
Epoch 2, Loss: 0.7155839323997497
Epoch 3, Loss: 0.7134337782859802
Epoch 4, Loss: 0.7118909060955048
Epoch 5, Loss: 0.7104454457759857
Epoch 6, Loss: 0.7088890850543976
Epoch 7, Loss: 0.7075893402099609
Epoch 8, Loss: 0.7062586903572082
Epoch 9, Loss: 0.7048002004623413
Epoch 10, Loss: 0.7034171283245086
Epoch 11, Loss: 0.7025231957435608
Epoch 12, Loss: 0.7015266597270966
Epoch 13, Loss: 0.7002574920654296
Epoch 14, Loss: 0.698747307062149
Epoch 15, Loss: 0.6977063953876496
Epoch 16, Loss: 0.6966621577739716
Epoch 17, Loss: 0.6960369169712066
Epoch 18, Loss: 0.6948440372943878
Epoch 19, Loss: 0.6939657866954804
Epoch 20, Loss: 0.6930074751377105
Epoch 21, Loss: 0.6920962750911712
Epoch 22, Loss: 0.6911981880664826
Epoch 23, Loss: 0.6903640866279602
Epoch 24, Loss: 0.6899523138999939
Epoch 25, Loss: 0.6889663338661194
Epoch 26, Loss: 0.6880228400230408
Epoch 27, Loss: 0.6870487809181214
Epoch 28, Loss: 0.686165726184845
Epoch 29, Loss: 0.685486626625061
Epoch 30, Loss: 0.6848306059837341
Epoch 31, Loss: 0.6839422225952149
Epoch 32, Loss: 0.6832305550575256
Epoch 33, Loss: 0.6824147522449493
Epoch 34, Loss: 0.6819980144500732
Epoch 35, Loss: 0.6810794234275818
Epoch 36, Loss: 0.6802286505699158
Epoch 37, Loss: 0.6798258125782013
Epoch 38, Loss: 0.6791172385215759
Epoch 39, Loss: 0.6782729268074036
Epoch 40, Loss: 0.6776126682758331
Epoch 41, Loss: 0.6768989622592926
Epoch 42, Loss: 0.6762053728103637
Epoch 43, Loss: 0.6756475090980529
Epoch 44, Loss: 0.6753955066204071
Epoch 45, Loss: 0.6745489716529847
Epoch 46, Loss: 0.6739313542842865
Epoch 47, Loss: 0.673326188325882
Epoch 48, Loss: 0.6729915618896485
Epoch 49, Loss: 0.6722120344638824
Epoch 50, Loss: 0.6718266606330872
Epoch 51, Loss: 0.6716600954532623
Epoch 52, Loss: 0.6705677092075348
Epoch 53, Loss: 0.670308530330658
Epoch 54, Loss: 0.6694739460945129
Epoch 55, Loss: 0.6696305751800538
Epoch 56, Loss: 0.6684424042701721
Epoch 57, Loss: 0.6680074572563172
Epoch 58, Loss: 0.66744304895401
Epoch 59, Loss: 0.6669774293899536
Epoch 60, Loss: 0.6665449023246766
Epoch 61, Loss: 0.6661204695701599
Epoch 62, Loss: 0.6657149434089661
Epoch 63, Loss: 0.665413784980774
Epoch 64, Loss: 0.6644633650779724
Epoch 65, Loss: 0.6640498042106628
Epoch 66, Loss: 0.6635482728481292
Epoch 67, Loss: 0.6631396770477295
Epoch 68, Loss: 0.6626389265060425
Epoch 69, Loss: 0.6620482861995697
Epoch 70, Loss: 0.6617119312286377
Epoch 71, Loss: 0.66120285987854
Epoch 72, Loss: 0.6607514917850494
Epoch 73, Loss: 0.6605069577693939
Epoch 74, Loss: 0.6601471483707428
Epoch 75, Loss: 0.6598217964172364
Epoch 76, Loss: 0.6590731561183929
Epoch 77, Loss: 0.6587798476219178
Epoch 78, Loss: 0.6581548631191254
Epoch 79, Loss: 0.6583406329154968
Epoch 80, Loss: 0.6577327191829682
Epoch 81, Loss: 0.6571271896362305
Epoch 82, Loss: 0.656671267747879
Epoch 83, Loss: 0.6562894463539124
Epoch 84, Loss: 0.655829232931137
Epoch 85, Loss: 0.6553720891475677
Epoch 86, Loss: 0.6552034676074981
Epoch 87, Loss: 0.6547967314720153
Epoch 88, Loss: 0.6542360842227936
Epoch 89, Loss: 0.6536154091358185
Epoch 90, Loss: 0.6535664677619935
Epoch 91, Loss: 0.6530918955802918
Epoch 92, Loss: 0.6530490577220917
Epoch 93, Loss: 0.6518986225128174
Epoch 94, Loss: 0.6515590190887451
Epoch 95, Loss: 0.6512846767902374
Epoch 96, Loss: 0.6512621939182281
Epoch 97, Loss: 0.6505548655986786
Epoch 98, Loss: 0.6501051962375641
Epoch 99, Loss: 0.6494227886199951
Epoch 100, Loss: 0.649199378490448
Epoch 101, Loss: 0.6486238539218903
Epoch 102, Loss: 0.6482992887496948
Epoch 103, Loss: 0.6479936063289642
Epoch 104, Loss: 0.6475347340106964
Epoch 105, Loss: 0.647095674276352
Epoch 106, Loss: 0.6466729760169982
Epoch 107, Loss: 0.6463548004627228
Epoch 108, Loss: 0.6456947088241577
Epoch 109, Loss: 0.6452789962291717
Epoch 110, Loss: 0.6445926547050476
Epoch 111, Loss: 0.644685173034668
Epoch 112, Loss: 0.6439492046833039
Epoch 113, Loss: 0.6432851135730744
Epoch 114, Loss: 0.6427569389343262
Epoch 115, Loss: 0.6423129439353943
Epoch 116, Loss: 0.6422940135002136
Epoch 117, Loss: 0.6415979444980622
Epoch 118, Loss: 0.6415235042572022
Epoch 119, Loss: 0.640673142671585
Epoch 120, Loss: 0.640322333574295
Epoch 121, Loss: 0.6399925887584687
Epoch 122, Loss: 0.6392323136329651
Epoch 123, Loss: 0.6386380940675735
Epoch 124, Loss: 0.6380882740020752
Epoch 125, Loss: 0.6379547357559204
Epoch 126, Loss: 0.6373067677021027
Epoch 127, Loss: 0.6371121287345887
Epoch 128, Loss: 0.6361018240451812
Epoch 129, Loss: 0.6354252219200134
Epoch 130, Loss: 0.6350532352924347
Epoch 131, Loss: 0.6341704726219177
Epoch 132, Loss: 0.6337720513343811
Epoch 133, Loss: 0.6329683005809784
Epoch 134, Loss: 0.6322911262512207
Epoch 135, Loss: 0.6322241604328156
Epoch 136, Loss: 0.6316525161266326
Epoch 137, Loss: 0.6310757458209991
Epoch 138, Loss: 0.6307730972766876
Epoch 139, Loss: 0.6301729202270507
Epoch 140, Loss: 0.62952039539814
Epoch 141, Loss: 0.6292171061038971
Epoch 142, Loss: 0.6288813889026642
Epoch 143, Loss: 0.6283952474594117
Epoch 144, Loss: 0.627724552154541
Epoch 145, Loss: 0.6275863885879517
Epoch 146, Loss: 0.6266201615333558
Epoch 147, Loss: 0.6264863908290863
Epoch 148, Loss: 0.6258289456367493
Epoch 149, Loss: 0.6254477679729462
Epoch 150, Loss: 0.6248232185840606
Epoch 151, Loss: 0.6242233693599701
Epoch 152, Loss: 0.6238783121109008
Epoch 153, Loss: 0.6234713077545166
Epoch 154, Loss: 0.6228610515594483
Epoch 155, Loss: 0.6227936685085297
Epoch 156, Loss: 0.6218818128108978
Epoch 157, Loss: 0.6216851651668549
Epoch 158, Loss: 0.6212225258350372
Epoch 159, Loss: 0.6207438766956329
Epoch 160, Loss: 0.6202704071998596
Epoch 161, Loss: 0.6198987662792206
Epoch 162, Loss: 0.6194111585617066
Epoch 163, Loss: 0.6188644766807556
Epoch 164, Loss: 0.6188222765922546
Epoch 165, Loss: 0.6180778264999389
Epoch 166, Loss: 0.6178474009037018
Epoch 167, Loss: 0.6173537015914917
Epoch 168, Loss: 0.6169044673442841
Epoch 169, Loss: 0.6166277647018432
Epoch 170, Loss: 0.6163138329982758
Epoch 171, Loss: 0.6159413099288941
Epoch 172, Loss: 0.615620756149292
Epoch 173, Loss: 0.6154781699180603
Epoch 174, Loss: 0.6147140145301819
Epoch 175, Loss: 0.6144913852214813
Epoch 176, Loss: 0.6138638436794281
Epoch 177, Loss: 0.6137386798858643
Epoch 178, Loss: 0.613091516494751
Epoch 179, Loss: 0.6130679965019226
Epoch 180, Loss: 0.6124846041202545
Epoch 181, Loss: 0.611948874592781
Epoch 182, Loss: 0.6114454686641693
Epoch 183, Loss: 0.6110112190246582
Epoch 184, Loss: 0.6105754792690277
Epoch 185, Loss: 0.6102881073951721
Epoch 186, Loss: 0.6100647509098053
Epoch 187, Loss: 0.6093645006418228
Epoch 188, Loss: 0.6090194493532181
Epoch 189, Loss: 0.6083994507789612
Epoch 190, Loss: 0.6081330448389053
Epoch 191, Loss: 0.60744249522686
Epoch 192, Loss: 0.6070296436548233
Epoch 193, Loss: 0.6067060708999634
Epoch 194, Loss: 0.6064476937055587
Epoch 195, Loss: 0.605833238363266
Epoch 196, Loss: 0.6054053664207458
Epoch 197, Loss: 0.6047285258769989
Epoch 198, Loss: 0.6042890906333923
Epoch 199, Loss: 0.604012879729271
Epoch 200, Loss: 0.6034434169530869
Epoch 201, Loss: 0.6030950725078583
Epoch 202, Loss: 0.6027598440647125
Epoch 203, Loss: 0.6020280301570893
Epoch 204, Loss: 0.6017543882131576
Epoch 205, Loss: 0.6019152939319611
Epoch 206, Loss: 0.6010530740022659
Epoch 207, Loss: 0.6006012290716172
Epoch 208, Loss: 0.6000461339950561
Epoch 209, Loss: 0.5997589021921158
Epoch 210, Loss: 0.5994018077850342
Epoch 211, Loss: 0.5992134541273118
Epoch 212, Loss: 0.5990976691246033
Epoch 213, Loss: 0.5982634246349334
Epoch 214, Loss: 0.5980139166116715
Epoch 215, Loss: 0.5975212603807449
Epoch 216, Loss: 0.5971673429012299
Epoch 217, Loss: 0.5973487019538879
Epoch 218, Loss: 0.5964401572942734
Epoch 219, Loss: 0.5963750749826431
Epoch 220, Loss: 0.5960692882537841
Epoch 221, Loss: 0.5956205189228058
Epoch 222, Loss: 0.5949371218681335
Epoch 223, Loss: 0.5950301826000214
Epoch 224, Loss: 0.5942545533180237
Epoch 225, Loss: 0.5941599071025848
Epoch 226, Loss: 0.5939353793859482
Epoch 227, Loss: 0.5931886374950409
Epoch 228, Loss: 0.5929647147655487
Epoch 229, Loss: 0.5927237868309021
Epoch 230, Loss: 0.5925554037094116
Epoch 231, Loss: 0.5920345067977906
Epoch 232, Loss: 0.592002984881401
Epoch 233, Loss: 0.5916295528411866
Epoch 234, Loss: 0.5912656903266906
Epoch 235, Loss: 0.5905644118785858
Epoch 236, Loss: 0.5912041962146759
Epoch 237, Loss: 0.5899196922779083
Epoch 238, Loss: 0.5898394078016281
Epoch 239, Loss: 0.5894883751869202
Epoch 240, Loss: 0.5890681833028794
Epoch 241, Loss: 0.5890004992485046
Epoch 242, Loss: 0.5883205324411392
Epoch 243, Loss: 0.5882011085748673
Epoch 244, Loss: 0.5878109097480774
Epoch 245, Loss: 0.5871420562267303
Epoch 246, Loss: 0.5871152281761169
Epoch 247, Loss: 0.5869166493415833
Epoch 248, Loss: 0.5864647209644318
Epoch 249, Loss: 0.5866314232349396
Epoch 250, Loss: 0.5854829847812653
Epoch 251, Loss: 0.5855059385299682
Epoch 252, Loss: 0.5851253151893616
Epoch 253, Loss: 0.5850923120975494
Epoch 254, Loss: 0.5846110463142395
Epoch 255, Loss: 0.5840506374835968
Epoch 256, Loss: 0.5840651422739029
Epoch 257, Loss: 0.5841655641794204
Epoch 258, Loss: 0.5833052456378937
Epoch 259, Loss: 0.5829112023115158
Epoch 260, Loss: 0.582865571975708
Epoch 261, Loss: 0.5821620136499405
Epoch 262, Loss: 0.5821201235055924
Epoch 263, Loss: 0.5812940508127212
Epoch 264, Loss: 0.5818515568971634
Epoch 265, Loss: 0.5809178173542022
Epoch 266, Loss: 0.5801565706729889
Epoch 267, Loss: 0.5798440545797348
Epoch 268, Loss: 0.5796185880899429
Epoch 269, Loss: 0.5792011708021164
Epoch 270, Loss: 0.579231470823288
Epoch 271, Loss: 0.5786095678806304
Epoch 272, Loss: 0.5784413158893585
Epoch 273, Loss: 0.5780120074748993
Epoch 274, Loss: 0.5777808219194412
Epoch 275, Loss: 0.5775089353322983
Epoch 276, Loss: 0.5769561767578125
Epoch 277, Loss: 0.5769323170185089
Epoch 278, Loss: 0.5762925386428833
Epoch 279, Loss: 0.5761769980192184
Epoch 280, Loss: 0.5752924859523774
Epoch 281, Loss: 0.5752902835607528
Epoch 282, Loss: 0.5750128507614136
Epoch 283, Loss: 0.5749126017093659
Epoch 284, Loss: 0.5740054368972778
Epoch 285, Loss: 0.5739078521728516
Epoch 286, Loss: 0.5734700471162796
Epoch 287, Loss: 0.5732420295476913
Epoch 288, Loss: 0.5730233132839203
Epoch 289, Loss: 0.572459477186203
Epoch 290, Loss: 0.5721667975187301
Epoch 291, Loss: 0.5719420671463012
Epoch 292, Loss: 0.5714848816394806
Epoch 293, Loss: 0.5711503505706788
Epoch 294, Loss: 0.5709364682435989
Epoch 295, Loss: 0.5710601389408112
Epoch 296, Loss: 0.5706405311822891
Epoch 297, Loss: 0.5701032459735871
Epoch 298, Loss: 0.5694433361291885
Epoch 299, Loss: 0.5692947596311569
Epoch 300, Loss: 0.5687483668327331
Accuracy on test data: 50.0%

実践??

[編集]
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns
import matplotlib.pyplot as plt

# データセットの作成
inputs_train = torch.randn(80, 10)  # 80未満のデータ
inputs_test = torch.randn(120, 10)  # 80以上のデータ
targets_train = torch.zeros(80)  # 80未満は不合格(0)
targets_test = torch.ones(120)  # 80以上は合格(1)

# データローダーの作成
train_dataset = TensorDataset(inputs_train, targets_train)
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)

test_dataset = TensorDataset(inputs_test, targets_test)
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# ネットワークの定義(同じネットワークを使用)
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ネットワークのインスタンス化
net = SimpleNet()

# 損失関数と最適化関数の定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# リストにエポックごとの損失を記録
losses = []

# ネットワークの訓練
for epoch in range(300):
    running_loss = 0.0
    for inputs, targets in train_dataloader:
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets.long())  # CrossEntropyLossはlong型のターゲットを受け取る
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    # エポックごとの損失を記録
    epoch_loss = running_loss / len(train_dataloader)
    losses.append(epoch_loss)

    # エポックごとの損失を表示
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

# モデルを保存
torch.save(net.state_dict(), 'model.pth')

# テストデータを用いてネットワークの精度を確認
correct = 0
total = 0
with torch.no_grad():
    for test_inputs, test_targets in test_dataloader:
        test_outputs = net(test_inputs)
        _, predicted = torch.max(test_outputs.data, 1)
        total += test_targets.size(0)
        correct += (predicted == test_targets).sum().item()

print(f"Accuracy on test data: {100 * correct / total}%")

# 損失のグラフを作成
sns.set()
plt.plot(losses, label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

このコードはモデルの保存も実行

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# ネットワークの定義(同じネットワークを使用)
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__() 
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 新しいデータを指定して予測を行う関数
def predict_new_data(model, new_data):
    dataset = TensorDataset(new_data)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=False)

    predictions = []
    with torch.no_grad():
        for inputs in dataloader:
            outputs = model(inputs[0])
            _, predicted = torch.max(outputs.data, 1)
            predictions.extend(predicted.tolist())
    
    return predictions

# ネットワークのインスタンス化
net = SimpleNet()

# 保存したモデルのパラメータを読み込む
net.load_state_dict(torch.load('model.pth'))

# 新しいデータの作成
new_inputs = torch.randn(50, 10)  # 50個の新しいデータ

# 新しいデータを指定してネットワークの推論を実行
predictions = predict_new_data(net, new_inputs)

# 合否のヒストグラムを作成して表示
plt.hist(predictions, bins=2, rwidth=0.8, align='left')
plt.xticks([0, 1], ['不合格', '合格'])
plt.xlabel('結果')
plt.ylabel('頻度')
plt.title('合否のヒストグラム')
plt.show()

PyTorchのインストール方法

[編集]

PyTorchはpipコマンドを使用して簡単にインストールすることができます。以下のコマンドを実行してください:

pip install torch