GitHub: 1.1k stars, created 10 months ago, last commit a month ago, 5 contributors, 341 commits
Stars added on GitHub, per day, on average: +1 yesterday, +0.1/day last week, +0.3/day last month

js-torch

PyTorch in JavaScript

  • JS-PyTorch is a Deep Learning JavaScript library built from scratch to closely follow PyTorch's syntax.
  • This library has GPU support, using GPU.js.
  • If you want to run it yourself, check out the Documentation.
  • Try out the Web Demo!

Note: You can install the package locally with npm install js-pytorch.


Implemented Tensor Operations:
Implemented Deep Learning Layers:

1. Table of Contents

2. Installation

  • On macOS, Windows, and Ubuntu, you can install the library with npm install js-pytorch.
  • On Windows, if you run into an error, you might need to install the latest version of Visual Studio, including the "Desktop development with C++" workload.
  • To run in the Browser, paste the following tag in the <head> of your HTML file:
<script src="https://cdnjs.cloudflare.com/ajax/libs/js-pytorch/0.7.2/js-pytorch-browser.js"
        integrity="sha512-l22t7GnqXvHBMCBvPUBdFO2TEYxnb1ziCGcDQcpTB2un16IPA4FE5SIZ8bUR+RwoDZGikQkWisO+fhnakXt9rg=="
        crossorigin="anonymous"
        referrerpolicy="no-referrer"></script>
  • After that, you can use JS-PyTorch freely in any <script> in your HTML file:
<head>
    <title>My Project</title>
    <!-- New script goes here -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/js-pytorch/0.7.2/js-pytorch-browser.js" 
            integrity="sha512-l22t7GnqXvHBMCBvPUBdFO2TEYxnb1ziCGcDQcpTB2un16IPA4FE5SIZ8bUR+RwoDZGikQkWisO+fhnakXt9rg=="
            crossorigin="anonymous" 
            referrerpolicy="no-referrer">
    </script>
    <!---->
</head>
<body>
    <script>
        let x = torch.randn([10,5])
        let linear = new torch.nn.Linear(5,1,'gpu',true)
        let z = linear.forward(x)
        console.log(z.data)
    </script>
</body>
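The integrity attribute in the tag above is a Subresource Integrity (SRI) value: the browser refuses to run the script if the downloaded file's SHA-512 digest does not match. As a minimal sketch in plain Node.js (not part of js-pytorch; the sriHash helper name is hypothetical), a value in the same sha512-<base64> format can be computed like this:

```javascript
const crypto = require("crypto");

// Build a Subresource Integrity string ("sha512-" + base64 digest)
// from file contents given as a string or Buffer.
function sriHash(contents) {
  const digest = crypto.createHash("sha512").update(contents).digest("base64");
  return `sha512-${digest}`;
}

// Hash some sample text. To check a real file, pass its bytes instead.
console.log(sriHash("console.log('hello');"));
```

To check a downloaded copy of the browser bundle, you could pass fs.readFileSync("js-pytorch-browser.js") to sriHash and compare the result with the integrity value in the tag.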

3. Running it Yourself

Simple Autograd Example:

// Require the Library if running in Node.js (not necessary in the browser):
const { torch } = require("js-pytorch");

// Pass device as an argument to a Tensor or nn.Module (same as PyTorch):
const device = 'gpu';

// Instantiate Tensors:
let x = torch.randn([8, 4, 5]);
let w = torch.randn([8, 5, 4], true, device);
let b = torch.tensor([0.2, 0.5, 0.1, 0.0], true);

// Make calculations:
let out = torch.matmul(x, w);
out = torch.add(out, b);

// Compute gradients on whole graph:
out.backward();

// Get gradients from specific Tensors:
console.log(w.grad);
console.log(b.grad);
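out.backward() performs reverse-mode automatic differentiation: it walks the computation graph from out back to the leaf Tensors and applies the chain rule at each operation. The sketch below is a scalar, micrograd-style illustration of that idea in plain JavaScript. The Value class is hypothetical and is not how js-pytorch implements its Tensors; x.mul(w).add(b) stands in for the matmul and add above.

```javascript
// A scalar node in a computation graph that remembers how it was produced.
class Value {
  constructor(data, parents = []) {
    this.data = data;
    this.grad = 0;
    this.parents = parents;
    this.backwardFn = () => {}; // pushes this.grad back to the parents
  }

  add(other) {
    const out = new Value(this.data + other.data, [this, other]);
    // d(a + b)/da = 1 and d(a + b)/db = 1
    out.backwardFn = () => {
      this.grad += out.grad;
      other.grad += out.grad;
    };
    return out;
  }

  mul(other) {
    const out = new Value(this.data * other.data, [this, other]);
    // d(a * b)/da = b and d(a * b)/db = a
    out.backwardFn = () => {
      this.grad += other.data * out.grad;
      other.grad += this.data * out.grad;
    };
    return out;
  }

  // Order the graph topologically, seed d(out)/d(out) = 1,
  // then run each node's backward step from the output to the inputs.
  backward() {
    const order = [];
    const visited = new Set();
    const visit = (node) => {
      if (visited.has(node)) return;
      visited.add(node);
      node.parents.forEach(visit);
      order.push(node);
    };
    visit(this);
    this.grad = 1;
    for (let i = order.length - 1; i >= 0; i--) order[i].backwardFn();
  }
}

// out = x * w + b, a scalar version of matmul followed by add.
const x = new Value(3);
const w = new Value(-2);
const b = new Value(0.5);
const out = x.mul(w).add(b);
out.backward();
console.log(w.grad, b.grad); // 3 1
```

Like PyTorch, this sketch accumulates gradients with +=, so repeated backward passes add up. That is presumably why the training loop below resets them with optimizer.zero_grad() after each step.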

Complex Autograd Example (Transformer):

// Require the Library if running in Node.js (not necessary in the browser):
const { torch } = require("js-pytorch");
const nn = torch.nn;
const optim = torch.optim;

const device = 'gpu';

// Define training hyperparameters:
const vocab_size = 52;
const hidden_size = 32;
const n_timesteps = 16;
const n_heads = 4;
const dropout_p = 0;
const batch_size = 8;

// Create Transformer decoder Module:
class Transformer extends nn.Module {
  constructor(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p, device) {
    super();
    // Instantiate Transformer's Layers:
    this.embed = new nn.Embedding(vocab_size, hidden_size);
    this.pos_embed = new nn.PositionalEmbedding(n_timesteps, hidden_size);
    this.b1 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps, dropout_p, device);
    this.b2 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps, dropout_p, device);
    this.ln = new nn.LayerNorm(hidden_size);
    this.linear = new nn.Linear(hidden_size, vocab_size, device);
  }

  forward(x) {
    let z;
    z = torch.add(this.embed.forward(x), this.pos_embed.forward(x));
    z = this.b1.forward(z);
    z = this.b2.forward(z);
    z = this.ln.forward(z);
    z = this.linear.forward(z);
    return z;
  }
}

// Instantiate your custom nn.Module:
const model = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p, device);

// Define loss function and optimizer:
const loss_func = new nn.CrossEntropyLoss();
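// Pass the learning rate and regularization positionally. Writing (lr = 5e-3)
// would assign to an undeclared global and throws a ReferenceError in strict mode.
const optimizer = new optim.Adam(model.parameters(), 5e-3, 0); // lr = 5e-3, reg = 0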

// Instantiate sample input and output:
let x = torch.randint(0, vocab_size, [batch_size, n_timesteps, 1]);
let y = torch.randint(0, vocab_size, [batch_size, n_timesteps]);
let loss;

// Training Loop:
for (let i = 0; i < 40; i++) {
  // Forward pass through the Transformer:
  let z = model.forward(x);

  // Get loss:
  loss = loss_func.forward(z, y);

  // Backpropagate the loss using torch.tensor's backward() method:
  loss.backward();

  // Update the weights:
  optimizer.step();

  // Reset the gradients to zero after each training step:
  optimizer.zero_grad();

  // Print loss at every iteration:
  console.log(`Iter ${i} - Loss ${loss.data[0].toFixed(4)}`);
}
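Each iteration above runs a forward pass, computes the cross-entropy loss, backpropagates, and lets Adam update the weights. As a hedged, self-contained illustration (plain JavaScript, not js-pytorch's code), here is softmax cross-entropy for a single row of logits together with the standard Adam update from Kingma & Ba (2015), using the same learning rate of 5e-3. The crossEntropy and adamStep helpers are hypothetical. The betas and eps are the usual defaults from the Adam paper, and the reg argument passed to optim.Adam above is left out.

```javascript
// Softmax cross-entropy for one row of logits and an integer target class.
// Returns the loss and its gradient w.r.t. the logits (softmax - one_hot).
function crossEntropy(logits, target) {
  const maxLogit = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((z) => Math.exp(z - maxLogit));
  const sum = exps.reduce((s, e) => s + e, 0);
  const probs = exps.map((e) => e / sum);
  const loss = -Math.log(probs[target]);
  const grad = probs.map((p, i) => p - (i === target ? 1 : 0));
  return { loss, grad };
}

// One Adam update applied in place to a flat parameter array.
// state holds the running first/second moments (m, v) and the step count t.
function adamStep(params, grads, state, lr = 5e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8) {
  state.t += 1;
  for (let i = 0; i < params.length; i++) {
    state.m[i] = beta1 * state.m[i] + (1 - beta1) * grads[i];
    state.v[i] = beta2 * state.v[i] + (1 - beta2) * grads[i] ** 2;
    const mHat = state.m[i] / (1 - beta1 ** state.t); // bias correction
    const vHat = state.v[i] / (1 - beta2 ** state.t);
    params[i] -= (lr * mHat) / (Math.sqrt(vHat) + eps);
  }
}

// Treat the logits themselves as the trainable parameters and take a few
// steps toward target class 2; the printed loss should go down.
const logits = [0.5, -0.2, 0.1];
const state = { m: [0, 0, 0], v: [0, 0, 0], t: 0 };
for (let i = 0; i < 5; i++) {
  const { loss, grad } = crossEntropy(logits, 2);
  console.log(`Iter ${i} - Loss ${loss.toFixed(4)}`);
  adamStep(logits, grad, state);
}
```

Because Adam divides by the square root of the second moment, each early step moves every parameter by roughly lr in the direction opposite its gradient's sign, regardless of the gradient's scale.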

Saving and Loading models:

// Instantiate your model:
const model = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p);

// Train the model (trainModel is a placeholder for your own training loop,
// such as the one in the Transformer example above):
trainModel(model);

// Save model to JSON file:
torch.save(model, 'model.json');

// To load, instantiate placeHolder using the original model's architecture:
const placeHolder = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p);

// Load weights into placeHolder:
const newModel = torch.load(placeHolder, 'model.json');
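The placeholder has to be built with the original model's architecture. A plausible reason, sketched here in plain JavaScript under the assumption that the JSON file stores parameter values but not the model's structure, is that loading copies each saved array into a model whose parameters already have matching sizes. makeModel, saveParams and loadParams are hypothetical helpers for illustration, not js-pytorch functions.

```javascript
// Hypothetical minimal "model": named parameters stored as flat arrays.
function makeModel(inSize, outSize) {
  return {
    weight: Array.from({ length: inSize * outSize }, () => Math.random() - 0.5),
    bias: new Array(outSize).fill(0),
  };
}

// Serialize every parameter array to a JSON string.
function saveParams(model) {
  return JSON.stringify(model);
}

// Copy saved values into a freshly built model with the same architecture.
// A size mismatch means the placeholder was built with different dimensions.
function loadParams(placeholder, json) {
  const saved = JSON.parse(json);
  for (const name of Object.keys(placeholder)) {
    if (!(name in saved) || saved[name].length !== placeholder[name].length) {
      throw new Error(`Parameter "${name}" is missing or has the wrong size`);
    }
    placeholder[name] = saved[name].slice();
  }
  return placeholder;
}

const trained = makeModel(5, 1);
const json = saveParams(trained);
const restored = loadParams(makeModel(5, 1), json);
console.log(restored.weight.every((v, i) => v === trained.weight[i])); // true
```

JSON.stringify writes each number with enough digits to round-trip exactly, so the restored weights equal the trained ones bit for bit.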

4. Distribution & Devtools

  • Build for Distribution by running npm run build. CJS and ESM modules and index.d.ts will be output in the dist/ folder.
  • Check the Code with ESLint at any time, running npm run lint.
  • Run the tests with npm test.
  • Improve Code Formatting with prettier, running npm run prettier.
  • Performance Benchmarks are also included in the tests/benchmarks/ directory. Run all benchmarks with npm run bench and save new benchmarks with npm run bench:update.

5. Future Work

  • This package is not as optimized as PyTorch yet, but I tried making it more interpretable. Efficiency improvements are incoming!
  • Feel free to contribute! Open a pull request to the develop branch, and also feel free to reach out. I'll try to answer as soon as possible.
  • Hope you enjoy!