Skip to content

Commit

Permalink
use public Replicate API to get model metadata
Browse files Browse the repository at this point in the history
and test it. getting that mocking setup right was the hard part 😅
  • Loading branch information
zeke committed Feb 7, 2024
1 parent ce62f19 commit 32f14f1
Show file tree
Hide file tree
Showing 6 changed files with 1,096 additions and 23 deletions.
28 changes: 18 additions & 10 deletions index.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const envFile = path.join(targetDir, '.env')

if (process.env.REPLICATE_API_TOKEN) {
fs.writeFileSync(envFile, `REPLICATE_API_TOKEN=${process.env.REPLICATE_API_TOKEN}`)
console.log('Adding API token to .env file')
console.log(`Adding API token ${process.env.REPLICATE_API_TOKEN.slice(0, 5)} to .env file...`)
} else {
console.log('API token not found in environment.')
const rl = readline.createInterface({ input: process.stdin, output: process.stdout })
Expand All @@ -47,31 +47,39 @@ if (process.env.REPLICATE_API_TOKEN) {
if (answer.toLowerCase() === 'y' || answer === '') {
await open('https://replicate.com/account')
const token = readlineSync.question('Paste your API token here: ', { hideEchoBack: true })

// Add the pasted token to the user's local .env file for when they run their project
fs.writeFileSync(envFile, `REPLICATE_API_TOKEN=${token}`)
console.log('API token written to .env file')

// Also add the pasted token to THIS script's environment, so we can use it to make Replicate API calls
process.env.REPLICATE_API_TOKEN = token

console.log(`API token ${process.env.REPLICATE_API_TOKEN.slice(0, 5)} written to .env file`)
}
}

// Check use-provided API token looks legit before proceeding
if (!process.env.REPLICATE_API_TOKEN.startsWith('r8_')) {
console.log('Invalid API token:', process.env.REPLICATE_API_TOKEN)
// process.exit(1)
}

console.log('Setting package name...')
execSync(`npm pkg set name=${args.packageName}`, { cwd: targetDir, stdio: 'ignore' })

console.log('Installing dependencies...')
execSync('npm install', { cwd: targetDir, stdio: 'ignore' })

let model
try {
model = await getModel(args.model)
} catch (e) {
console.error('Model not found:', args.model)
process.exit()
}
console.log('Fetching model metadata using Replicate API...')
const model = await getModel(args.model)

// If user has provided a model version, use it. Otherwise, use the latest version
const modelVersionRegexp = /.*:[a-fA-F0-9]{64}$/
const modelNameWithVersion = args.model.match(modelVersionRegexp) ? args.model : getModelNameWithVersion(model)

const inputs = getModelInputs(model)

console.log('Adding model data and inputs to index.js...')
const indexFile = path.join(targetDir, 'index.js')
const indexFileContents = fs.readFileSync(indexFile, 'utf8')
const newContents = indexFileContents
Expand All @@ -82,7 +90,7 @@ fs.writeFileSync(indexFile, newContents)
console.log('App created successfully!')

if (args['run-after-setup']) {
console.log(`Running command: \`node ${args.packageName}/index.js\`\n\n`)
console.log(`Running command: \`node ${args.packageName}/index.js\``)
execSync('node index.js', { cwd: targetDir, stdio: 'inherit' })
} else {
console.log('To run your app, execute the following command:')
Expand Down
13 changes: 8 additions & 5 deletions index.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ describe('Node script test', () => {
})

it('should create a directory with expected files', () => {
const command = `REPLICATE_API_TOKEN=test_token node index.mjs ${directoryName} --run-after-setup=false`
const command = `REPLICATE_API_TOKEN=r8_test_token node index.mjs ${directoryName} --run-after-setup=false`

// set stdio to 'inherit' to see script output in test output
execSync(command, { stdio: 'ignore', env: process.env })

// Check if the directory exists
Expand All @@ -40,7 +41,7 @@ describe('Node script test', () => {
const envFile = path.join(directoryName, '.env')
expect(fileExists(envFile)).toBe(true)
const envFileContents = fs.readFileSync(envFile, 'utf8')
expect(envFileContents).toBe('REPLICATE_API_TOKEN=test_token')
expect(envFileContents).toBe('REPLICATE_API_TOKEN=r8_test_token')

// Check if .gitignore exists in the directory
const gitignoreFile = path.join(directoryName, '.gitignore')
Expand All @@ -50,8 +51,9 @@ describe('Node script test', () => {
})

it('handles basic `model` argument in the form {owner}/{model}', () => {
const command = `REPLICATE_API_TOKEN=test_token node index.mjs ${directoryName} --model=yorickvp/llava-13b --run-after-setup=false`
const command = `REPLICATE_API_TOKEN=r8_test_token node index.mjs ${directoryName} --model=yorickvp/llava-13b --run-after-setup=false`

// set stdio to 'inherit' to see script output in test output
execSync(command, { stdio: 'ignore', env: process.env })

// Check if the directory exists
Expand All @@ -63,12 +65,13 @@ describe('Node script test', () => {

// Check if index.js contains the correct model name
const indexFileContents = fs.readFileSync(indexFile, 'utf8')
expect(indexFileContents).toMatch(/yorickvp\/llava-13b:[a-zA-Z0-9]{40}/)
expect(indexFileContents).toMatch(/yorickvp\/llava-13b:[a-zA-Z0-9]{64}/)
})

it('handles a `model` argument in the form {owner}/{model}:{version}', () => {
const command = `REPLICATE_API_TOKEN=test_token node index.mjs ${directoryName} --model=yorickvp/llava-13b:2cfef05a8e8e648f6e92ddb53fa21a81c04ab2c4f1390a6528cc4e331d608df8 --run-after-setup=false`
const command = `REPLICATE_API_TOKEN=r8_test_token node index.mjs ${directoryName} --model=yorickvp/llava-13b:2cfef05a8e8e648f6e92ddb53fa21a81c04ab2c4f1390a6528cc4e331d608df8 --run-after-setup=false`

// set stdio to 'inherit' to see script output in test output
execSync(command, { stdio: 'ignore', env: process.env })

// Check if the directory exists
Expand Down
21 changes: 15 additions & 6 deletions lib/models.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import models from 'all-the-public-replicate-models'
import Replicate from 'replicate'
import fs from 'fs'
import path from 'path'

export function getModelInputs (model) {
return model.default_example.input
Expand All @@ -10,13 +12,20 @@ export function getModelNameWithVersion (model) {

export async function getModel (fullModelName) {
// Extract owner and model name, omitting the version if it's present
const [owner, modelName] = fullModelName.split(':')[0].split('/')
const [owner, name] = fullModelName.split(':')[0].split('/')

const model = models.find(model => model.owner === owner && model.name === modelName)

if (!model) {
throw new Error(`Model "${fullModelName}" not found`)
if (process.env.REPLICATE_API_TOKEN === 'r8_test_token') {
const filePath = path.join(process.cwd(), 'test', 'fixtures', owner, `${name}.json`)
const fileContents = fs.readFileSync(filePath, 'utf8')
const loadedModel = JSON.parse(fileContents)
return loadedModel
}

// Instantiate a Replicate client on the fly instead of at the top of this module,
// as the API token may have been user-provided and added to the process env AFTER this script's import time.
const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN })

const model = await replicate.models.get(owner, name)

return model
}
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
"node": ">=18"
},
"dependencies": {
"all-the-public-replicate-models": "^1.104.0",
"json5": "^2.2.3",
"minimist": "^1.2.8",
"open": "^10.0.3",
"readline-sync": "^1.4.10"
"readline-sync": "^1.4.10",
"replicate": "^0.25.2"
},
"devDependencies": {
"standard": "^17.1.0",
Expand Down
Loading

0 comments on commit 32f14f1

Please sign in to comment.