Fix Types in Build (#585)
Fixes #584
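(The types were previously imported from the deep `@huggingface/tasks/src/…` path, which does not resolve from the published package; this commit defines them directly in `@huggingface/inference` and re-exports them.)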
abdel-17 authored Mar 27, 2024
1 parent b89956d commit 4f70dba
Showing 2 changed files with 204 additions and 3 deletions.
204 changes: 203 additions & 1 deletion packages/inference/src/tasks/nlp/textGeneration.ts
@@ -1,8 +1,210 @@
-import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks/src/tasks/text-generation/inference";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
+/**
+ * Inputs for Text Generation inference
+ */
+export interface TextGenerationInput {
+  /**
+   * The text to initialize generation with
+   */
+  inputs: string;
+  /**
+   * Additional inference parameters
+   */
+  parameters?: TextGenerationParameters;
+  /**
+   * Whether to stream output tokens
+   */
+  stream?: boolean;
+  [property: string]: unknown;
+}
+
+/**
+ * Additional inference parameters for Text Generation
+ */
+export interface TextGenerationParameters {
+  /**
+   * The number of sampling queries to run. Only the best one (in terms of total logprob) will
+   * be returned.
+   */
+  best_of?: number;
+  /**
+   * Whether or not to output decoder input details
+   */
+  decoder_input_details?: boolean;
+  /**
+   * Whether or not to output details
+   */
+  details?: boolean;
+  /**
+   * Whether to use logits sampling instead of greedy decoding when generating new tokens.
+   */
+  do_sample?: boolean;
+  /**
+   * The maximum number of tokens to generate.
+   */
+  max_new_tokens?: number;
+  /**
+   * The parameter for repetition penalty. A value of 1.0 means no penalty. See [this
+   * paper](https://hf.co/papers/1909.05858) for more details.
+   */
+  repetition_penalty?: number;
+  /**
+   * Whether to prepend the prompt to the generated text.
+   */
+  return_full_text?: boolean;
+  /**
+   * The random sampling seed.
+   */
+  seed?: number;
+  /**
+   * Stop generating tokens if a member of `stop_sequences` is generated.
+   */
+  stop_sequences?: string[];
+  /**
+   * The value used to modulate the logits distribution.
+   */
+  temperature?: number;
+  /**
+   * The number of highest-probability vocabulary tokens to keep for top-k filtering.
+   */
+  top_k?: number;
+  /**
+   * If set to < 1, only the smallest set of most probable tokens with probabilities that add
+   * up to `top_p` or higher are kept for generation.
+   */
+  top_p?: number;
+  /**
+   * Truncate input tokens to the given size.
+   */
+  truncate?: number;
+  /**
+   * Typical Decoding mass. See [Typical Decoding for Natural Language
+   * Generation](https://hf.co/papers/2202.00666) for more information.
+   */
+  typical_p?: number;
+  /**
+   * Watermarking with [A Watermark for Large Language Models](https://hf.co/papers/2301.10226)
+   */
+  watermark?: boolean;
+  [property: string]: unknown;
+}
+
+/**
+ * Outputs for Text Generation inference
+ */
+export interface TextGenerationOutput {
+  /**
+   * When enabled, details about the generation
+   */
+  details?: TextGenerationOutputDetails;
+  /**
+   * The generated text
+   */
+  generated_text: string;
+  [property: string]: unknown;
+}
+
+/**
+ * When enabled, details about the generation
+ */
+export interface TextGenerationOutputDetails {
+  /**
+   * Details about additional sequences when best_of is provided
+   */
+  best_of_sequences?: TextGenerationOutputSequenceDetails[];
+  /**
+   * The reason why the generation was stopped.
+   */
+  finish_reason: TextGenerationFinishReason;
+  /**
+   * The number of generated tokens
+   */
+  generated_tokens: number;
+  /**
+   * Tokens from the input prompt (prefill)
+   */
+  prefill: TextGenerationPrefillToken[];
+  /**
+   * The random seed used for generation
+   */
+  seed?: number;
+  /**
+   * The generated tokens and associated details
+   */
+  tokens: TextGenerationOutputToken[];
+  /**
+   * Most likely tokens
+   */
+  top_tokens?: Array<TextGenerationOutputToken[]>;
+  [property: string]: unknown;
+}
+
+/**
+ * Details about an additional sequence generated when best_of is provided.
+ */
+export interface TextGenerationOutputSequenceDetails {
+  /**
+   * The reason why the generation was stopped.
+   */
+  finish_reason: TextGenerationFinishReason;
+  /**
+   * The generated text
+   */
+  generated_text: string;
+  /**
+   * The number of generated tokens
+   */
+  generated_tokens: number;
+  /**
+   * Tokens from the input prompt (prefill)
+   */
+  prefill: TextGenerationPrefillToken[];
+  /**
+   * The random seed used for generation
+   */
+  seed?: number;
+  /**
+   * The generated tokens and associated details
+   */
+  tokens: TextGenerationOutputToken[];
+  /**
+   * Most likely tokens
+   */
+  top_tokens?: Array<TextGenerationOutputToken[]>;
+  [property: string]: unknown;
+}
+
+/**
+ * Prompt (prefill) token.
+ */
+export interface TextGenerationPrefillToken {
+  /**
+   * Token ID from the model tokenizer
+   */
+  id: number;
+  /**
+   * The log probability of the token
+   */
+  logprob: number;
+  /**
+   * The text associated with that token
+   */
+  text: string;
+  [property: string]: unknown;
+}
+
+/**
+ * Generated token.
+ */
+export interface TextGenerationOutputToken {
+  /**
+   * Token ID from the model tokenizer
+   */
+  id: number;
+  /**
+   * The log probability of the token
+   */
+  logprob?: number;
+  /**
+   * Whether or not that token is a special one
+   */
+  special: boolean;
+  /**
+   * The text associated with that token
+   */
+  text: string;
+  [property: string]: unknown;
+}
+
+/**
+ * The reason why the generation was stopped.
+ *
+ * length: The generated sequence reached the maximum allowed length
+ *
+ * eos_token: The model generated an end-of-sentence (EOS) token
+ *
+ * stop_sequence: One of the sequences in `stop_sequences` was generated
+ */
+export type TextGenerationFinishReason = "length" | "eos_token" | "stop_sequence";
+
 /**
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
  */
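The inlined types above are consumed by the `textGeneration` helper, whose body sits in the truncated remainder of this file. A minimal usage sketch against the published `@huggingface/inference` API — the access token is a placeholder, and `gpt2` is simply the demo model the doc comment recommends:

import { textGeneration } from "@huggingface/inference";

async function demo(): Promise<void> {
  const output = await textGeneration({
    accessToken: "hf_...", // placeholder credential
    model: "gpt2", // demo model recommended in the doc comment above
    inputs: "The answer to the universe is",
    parameters: {
      max_new_tokens: 20, // cap the generation length
      temperature: 0.7, // modulate the logits distribution
      return_full_text: false, // do not prepend the prompt to the result
    },
  });
  // generated_text is the only required field on TextGenerationOutput;
  // `details` is populated only when requested and supported by the backend.
  console.log(output.generated_text);
}

Requesting `details: true` (against a backend that supports it) fills in `TextGenerationOutputDetails`, including per-token data and the `finish_reason`.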
3 changes: 1 addition & 2 deletions packages/inference/src/tasks/nlp/textGenerationStream.ts
@@ -1,7 +1,6 @@
 import type { BaseArgs, Options } from "../../types";
 import { streamingRequest } from "../custom/streamingRequest";
-
-import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";
+import type { TextGenerationInput } from "./textGeneration";
 
 export interface TextGenerationStreamToken {
   /** Token ID from the model tokenizer */
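The rest of this file (the unchanged streaming types) is truncated above. With the re-export in place, the streaming entry point type-checks against the same local `TextGenerationInput`. A hedged sketch of consuming the async generator returned by `textGenerationStream`, under the same placeholder-credential assumption:

import { textGenerationStream } from "@huggingface/inference";

async function streamDemo(): Promise<void> {
  for await (const event of textGenerationStream({
    accessToken: "hf_...", // placeholder credential
    model: "gpt2",
    inputs: "Once upon a time",
    parameters: { max_new_tokens: 30 },
  })) {
    // Each event carries one TextGenerationStreamToken; the aggregate
    // generated_text is null until the final event, so print incrementally.
    process.stdout.write(event.token.text);
  }
}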
