メインコンテンツまでスキップ

PyTorch

Open In Colab

PyTorchは、特に研究者の間で、Pythonにおけるディープラーニングのための最も人気のあるフレームワークの一つです。W&Bは、PyTorchに対して、勾配のログ取得からCPUとGPUのコードプロファイリングまで、最高レベルのサポートを提供しています。

備考

私たちのインテグレーションをColabノートブックでお試しいただくか(ビデオ解説付き)、スクリプトが含まれているexampleリポジトリをご覧ください。これには、Fashion MNISTHyperbandを使用したハイパーパラメータ最適化に関するものが含まれています。さらに、それが生成するW&B ダッシュボードもご覧いただけます。

wandb.watchを使った勾配の記録

自動的に勾配をログに記録するには、wandb.watchを呼び出して、PyTorchモデルを渡します。

import wandb

wandb.init(config=args)

model = ... # モデルの設定

# Magic
wandb.watch(model, log_freq=100)

model.train()
for batch_idx, (data, target) in enumerate(train_loader):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
wandb.log({"loss": loss})

同じスクリプト内で複数のモデルをトラッキングする必要がある場合は、それぞれのモデルでwandb.watchを呼び出すことができます。この関数のリファレンスドキュメントはこちらです。

注意

勾配、メトリクス、およびグラフは、forwardパスとbackwardパスの後にwandb.logが呼び出されるまでログに記録されません。

画像やメディアのログ記録

PyTorchのTensorsに画像データを渡すと、wandb.Imagetorchvisionのユーティリティが自動的に画像に変換します。

images_t = ...  # PyTorch Tensorsとして画像を生成または読み込み
wandb.log({"examples": [wandb.Image(im) for im in images_t]})

PyTorchや他のフレームワークでW&Bにリッチメディアをログに記録する方法については、media logging guideを参照してください。

また、メディアに情報を添えて記録したい場合(モデルの予測や派生したメトリクスなど)、wandb.Tableを使用してください。

my_table = wandb.Table()

my_table.add_column("image", images_t)
my_table.add_column("label", labels)
my_table.add_column("class_prediction", predictions_t)

# W&BにTableをログする
wandb.log({"mnist_predictions": my_table})

上のコードでこのようなテーブルが生成されます。このモデルは良さそうです! The code above generates a table like this one. This model's looking good!

データセットとモデルのログや可視化について詳しくは、W&B Tables のガイドを参照してください。

PyTorch コードのプロファイリング

W&Bダッシュボード内でPyTorchコード実行の詳細トレースを表示。

W&Bは PyTorch KinetoTensorboardプラグインと直接統合して、PyTorch コードのプロファイリングツール、CPU と GPU 通信の詳細の検査、ボトルネックの特定や最適化のためのツールを提供しています。

profile_dir = "path/to/run/tbprofile/"
profiler = torch.profiler.profile(
schedule=schedule, # スケジューリングの詳細についてはプロファイラのドキュメントを参照してください
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
with_stack=True,
)

with profiler:
... # ここでプロファイリングしたいコードを実行
# 詳細な使用方法については、プロファイラのドキュメントを参照してください

# wandbアーティファクトを作成
profile_art = wandb.Artifact("trace", type="profile")
# pt.trace.jsonファイルをアーティファクトに追加
profile_art.add_file(glob.glob(profile_dir + ".pt.trace.json"))
# アーティファクトをログに記録
profile_art.save()

このColabで実例コードを確認して実行してください。

:::注意

インタラクティブなトレースビューアツールは、Chrome Trace Viewerをベースにしており、Chromeブラウザで最適に動作します。

:::

Was this page helpful?👍👎